import os
import torch
from tqdm import tqdm
from src.utils.videoio import load_video_to_cv2
import cv2


class GeneratorWithLen(object):
    """ From https://stackoverflow.com/a/7460929 """
    def __init__(self, gen, length):
        self.gen = gen
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen


def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    return list(gen)


def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """ Provide a generator with a __len__ method so that it can be passed to
    functions that call len(). """

    if os.path.isfile(images):  # handle video to images
        # TODO: Create a generator version of load_video_to_cv2
        images = load_video_to_cv2(images)

    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    gen_with_len = GeneratorWithLen(gen, len(images))
    return gen_with_len
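
# Minimal usage sketch (illustrative only; 'my_video.mp4' is a hypothetical path):
# the wrapper reports its length before any frame is enhanced, so it can be handed
# to code that calls len(), e.g. progress bars or writers that pre-allocate.
#
#     frames = enhancer_generator_with_len('my_video.mp4', method='gfpgan')
#     total = len(frames)        # known up front, before any enhancement runs
#     for frame in frames:       # frames are then produced lazily, one at a time
#         ...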


def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """ Provide a generator so that the enhanced images do not all need to be held
    in memory at the same time. Compared to the list-based enhancer_list, this can
    save a large amount of RAM on long videos. """

    try:
        from gfpgan import GFPGANer
    except ImportError:
        print("GFPGAN library not found. Installing...")
        try:
            # Install the library with the pip of the current interpreter
            import subprocess
            import sys
            subprocess.check_call([sys.executable, "-m", "pip", "install", "gfpgan"])
            # Retry the import after installation
            from gfpgan import GFPGANer
            print("GFPGAN library installed successfully!")
        except Exception as e:
            print(f"Failed to install GFPGAN library. Error: {e}")
            # Without GFPGANer the generator cannot run, so re-raise
            raise

    print('face enhancer....')
    if not isinstance(images, list) and os.path.isfile(images):  # handle video to images
        images = load_video_to_cv2(images)

    # ------------------------ set up GFPGAN restorer ------------------------
    if method == 'gfpgan':
        arch = 'clean'
        channel_multiplier = 2
        model_name = 'GFPGANv1.4'
        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
    elif method == 'RestoreFormer':
        arch = 'RestoreFormer'
        channel_multiplier = 2
        model_name = 'RestoreFormer'
        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
    elif method == 'codeformer':  # TODO:
        arch = 'CodeFormer'
        channel_multiplier = 2
        model_name = 'CodeFormer'
        url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
    else:
        raise ValueError(f'Wrong model version {method}.')

    # ------------------------ set up background upsampler ------------------------
    if bg_upsampler == 'realesrgan':
        if not torch.cuda.is_available():  # CPU
            import warnings
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                          'If you really want to use it, please modify the corresponding codes.')
            bg_upsampler = None
        else:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            bg_upsampler = RealESRGANer(
                scale=2,
                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=True)  # need to set False in CPU mode
    else:
        bg_upsampler = None

    # determine model paths
    model_path = os.path.join('gfpgan/weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        model_path = os.path.join('checkpoints', model_name + '.pth')
    if not os.path.isfile(model_path):
        # download pre-trained models from url
        model_path = url

    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=bg_upsampler)

    # ------------------------ restore ------------------------
    for idx in tqdm(range(len(images)), 'Face Enhancer:'):
        # frames arrive in RGB order; GFPGAN expects BGR input
        img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)

        # restore faces and background if necessary
        cropped_faces, restored_faces, r_img = restorer.enhance(
            img,
            has_aligned=False,
            only_center_face=False,
            paste_back=True)

        # convert back to RGB before yielding
        r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
        yield r_img
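

# ---------------------------------------------------------------------------
# Usage sketch, assuming the module is run from the repository root so that the
# relative 'gfpgan/weights/' and 'checkpoints/' model paths above resolve. The
# default input/output paths are hypothetical examples, and imageio is used here
# only for the demo (assumed to be installed alongside the project).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    import imageio

    parser = argparse.ArgumentParser(description='Enhance the faces in a video.')
    parser.add_argument('--input', default='examples/source_video.mp4', help='input video (example path)')
    parser.add_argument('--output', default='enhanced.mp4', help='output video path')
    parser.add_argument('--method', default='gfpgan', choices=['gfpgan', 'RestoreFormer'])
    args = parser.parse_args()

    # Stream frames straight from the generator into the writer so the enhanced
    # frames never all sit in memory at once.
    writer = imageio.get_writer(args.output, fps=25)
    for frame in enhancer_generator_no_len(args.input, method=args.method):
        writer.append_data(frame)
    writer.close()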