"""GFPGAN (v1.3) face restoration wrapper with a configurable background upsampler.

Running this module as a script restores one sample frame from the EyeDentify
data and writes the result to ``<ROOT_DIR>/rough_works/SR_imgs``.
"""

import os
import cv2
import sys
import torch
import os.path as osp
from gfpgan import GFPGANer
from basicsr.utils.download_util import load_file_from_url

# Repository root: two directory levels above this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Same directory as ROOT_DIR (kept under its original module-level name); it is
# put on sys.path so the project-local ``SR_Inference`` package is importable.
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
if root_path not in sys.path:
    sys.path.append(root_path)

# Must come after the sys.path update above.
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo  # noqa: E402


class GFPGAN:
    """Callable GFPGAN v1.3 restorer.

    Wraps ``GFPGANer`` together with the background upsampler exposed by
    ``RealEsrUpsamplerZoo``. Weights are looked up in
    ``<ROOT_DIR>/SR_Inference/gfpgan/weights``. If they are missing, they are
    downloaded from the GFPGAN v1.3.0 GitHub release.
    """

    MODEL_FILE_NAME = "GFPGANv1.3.pth"
    MODEL_URL = (
        "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
    )

    def __init__(
        self,
        upscale=2,
        bg_upsampler_name="realesrgan",
        prefered_net_in_upsampler="RRDBNet",
    ):
        """Build the background upsampler and load the GFPGAN model.

        Args:
            upscale: Output scale factor. It is coerced with ``int()`` and passed
                to both the background upsampler zoo and ``GFPGANer``.
            bg_upsampler_name: Forwarded to ``RealEsrUpsamplerZoo`` to select the
                background upsampler.
            prefered_net_in_upsampler: Forwarded to ``RealEsrUpsamplerZoo``.
                The spelling is kept for compatibility with existing callers.
        """
        upscale = int(upscale)
        # Use the GPU when available; otherwise fall back to CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------------ set up background upsampler ------------------------
        upsampler_zoo = RealEsrUpsamplerZoo(
            upscale=upscale,
            bg_upsampler_name=bg_upsampler_name,
            prefered_net_in_upsampler=prefered_net_in_upsampler,
        )
        bg_upsampler = upsampler_zoo.bg_upsampler

        # ------------------------ load model ------------------------
        self.sr_model = GFPGANer(
            upscale=upscale,
            bg_upsampler=bg_upsampler,
            model_path=self._resolve_model_path(),
            device=device,
        )

    @classmethod
    def _resolve_model_path(cls):
        """Return the local GFPGAN weights path, downloading the file if absent."""
        weights_dir = os.path.join(ROOT_DIR, "SR_Inference", "gfpgan", "weights")
        model_path = os.path.join(weights_dir, cls.MODEL_FILE_NAME)
        if os.path.isfile(model_path):
            return model_path
        # load_file_from_url returns the path of the downloaded file.
        return load_file_from_url(
            url=cls.MODEL_URL,
            model_dir=weights_dir,
            progress=True,
            file_name=cls.MODEL_FILE_NAME,
        )

    def __call__(self, img, **enhance_kwargs):
        """Restore/enhance ``img`` with GFPGAN and return the full output image.

        Args:
            img: Image array as produced by ``cv2.imread``. It is presumably a
                BGR ``uint8`` array; confirm against ``GFPGANer.enhance``.
            **enhance_kwargs: Extra keyword arguments forwarded unchanged to
                ``GFPGANer.enhance`` (for example ``only_center_face`` or
                ``weight``). With none given, the call matches the previous
                behavior.

        Returns:
            The third value returned by ``GFPGANer.enhance``, i.e. the
            restored full image.

        Raises:
            ValueError: If ``img`` is None (for example, after a failed
                ``cv2.imread``).
        """
        if img is None:
            raise ValueError("img is None; the input image could not be loaded")
        # enhance() yields (cropped_faces, restored_faces, restored_img); only
        # the full restored image is returned to the caller.
        _, _, sr_img = self.sr_model.enhance(img, **enhance_kwargs)
        return sr_img


def main():
    """Restore one sample frame and save it under ``rough_works/SR_imgs``."""
    img_path = f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png"
    # Read the input before loading the model so a bad path fails fast.
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    gfpgan = GFPGAN(
        upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
    )
    sr_img = gfpgan(img=img)

    saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
    os.makedirs(saving_dir, exist_ok=True)
    out_path = f"{saving_dir}/sr_img_gfpgan.png"
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(out_path, sr_img):
        raise OSError(f"Failed to write image: {out_path}")


if __name__ == "__main__":
    main()