# Spaces: Running
import os | |
import cv2 | |
import sys | |
import torch | |
import numpy as np | |
import os.path as osp | |
from PIL import Image | |
from basicsr.utils import img2tensor | |
from basicsr.archs.srresnet_arch import MSRResNet | |
# Repository root: two directory levels above this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Alias kept for backward compatibility; it names the same directory as ROOT_DIR
# (the original recomputed it via osp.join(__file__, "..", "..")).
root_path = ROOT_DIR
# Make the repository root importable without piling up duplicate entries when the
# module is re-imported.
# NOTE(review): this runs after the basicsr imports above, so it cannot affect how
# those are resolved.
if root_path not in sys.path:
    sys.path.append(root_path)
class SRResNet:
    """Super-resolve images with a pretrained BasicSR ``MSRResNet``.

    Weights are loaded from
    ``<ROOT_DIR>/SR_Inference/srresnet/weights/SRResNet_<upscale>x.pth``.
    Instances are callable: pass an ``H x W x C`` image with values in [0, 255]
    (BGR channel order when C == 3, as produced by ``cv2.imread``) and receive the
    network output as a ``uint8`` array in the same channel order.
    """

    def __init__(self, upscale=2, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16):
        """Build the network on the best available device and load its checkpoint.

        Args:
            upscale: Upscaling factor (coerced to ``int``); also selects the
                checkpoint file name.
            num_in_ch, num_out_ch, num_feat, num_block: Passed through to
                ``MSRResNet``; they must match the saved checkpoint.

        Raises:
            Whatever ``torch.load`` raises for a missing or unreadable checkpoint,
            and ``load_state_dict`` errors if the architecture does not match.
        """
        self.upscale = int(upscale)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------ load model for img enhancement -------------------
        self.sr_model = MSRResNet(
            upscale=self.upscale,
            num_in_ch=num_in_ch,
            num_out_ch=num_out_ch,
            num_feat=num_feat,
            num_block=num_block,
        ).to(self.device)

        ckpt_path = os.path.join(
            ROOT_DIR,
            "SR_Inference",
            "srresnet",
            "weights",
            f"SRResNet_{self.upscale}x.pth",
        )
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        loadnet = torch.load(ckpt_path, map_location=self.device)
        self.sr_model.load_state_dict(self._select_state_dict(loadnet))
        self.sr_model.eval()

    @staticmethod
    def _select_state_dict(checkpoint):
        """Pick the weights to load from a loaded checkpoint object.

        BasicSR-style checkpoints nest weights under ``"params_ema"`` (preferred)
        or ``"params"``; any other object is assumed to be a bare state dict.
        """
        for key in ("params_ema", "params"):
            if key in checkpoint:
                return checkpoint[key]
        return checkpoint

    def __call__(self, img):
        """Run super-resolution on a single image.

        Args:
            img: ``H x W x C`` array with values in [0, 255]; BGR order when C == 3.

        Returns:
            ``uint8`` ``H' x W' x C`` array of the network output, converted back
            to BGR order when it has 3 channels.

        Raises:
            ValueError: If ``img`` is ``None`` (e.g. the result of a failed
                ``cv2.imread``).
        """
        if img is None:
            raise ValueError("img is None; was the image read successfully?")
        batch = self._to_tensor(img).unsqueeze(0).to(self.device)
        # Inference only: without no_grad the output requires grad and .numpy()
        # raises, and an unnecessary autograd graph is kept in memory.
        with torch.no_grad():
            restored = self.sr_model(batch)[0]
        return self._to_image(restored)

    @staticmethod
    def _to_tensor(img):
        """Convert an ``H x W x C`` [0, 255] image to a ``C x H x W`` float32 tensor in [0, 1].

        3-channel input is reordered BGR -> RGB (same behaviour as basicsr's
        ``img2tensor(..., bgr2rgb=True, float32=True)``).
        """
        arr = np.asarray(img) / 255.0
        if arr.shape[2] == 3:
            arr = arr[..., ::-1]
        # torch.from_numpy rejects negative strides, so materialise the view first.
        return torch.from_numpy(np.ascontiguousarray(arr.transpose(2, 0, 1))).float()

    @staticmethod
    def _to_image(tensor):
        """Convert a ``C x H x W`` tensor in [0, 1] to an ``H x W x C`` uint8 array.

        3-channel output is reordered RGB -> BGR.
        """
        # Clamp to the valid range rather than min-max stretching: stretching alters
        # brightness/contrast and divides by zero on a constant-valued output.
        arr = tensor.detach().float().clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy()
        arr = (arr * 255.0).round().astype(np.uint8)  # round, don't truncate
        if arr.shape[2] == 3:
            arr = arr[..., ::-1]
        return np.ascontiguousarray(arr)
if __name__ == "__main__":
    # Demo: upscale one EyeDentify frame 2x and save the result under rough_works/.
    srresnet = SRResNet(upscale=2)
    img_path = os.path.join(
        ROOT_DIR, "data", "EyeDentify", "Wo_SR", "original", "1", "1", "frame_01.png"
    )
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    sr_img = srresnet(img=img)
    saving_dir = os.path.join(ROOT_DIR, "rough_works", "SR_imgs")
    os.makedirs(saving_dir, exist_ok=True)
    out_path = os.path.join(saving_dir, "sr_img_srresnet.png")
    # cv2.imwrite likewise returns False on failure rather than raising.
    if not cv2.imwrite(out_path, sr_img):
        raise OSError(f"Could not write image: {out_path}")