# pupilsense / SR_Inference / inference_srresnet.py
# vijul.shah
# End-to-End Pipeline Configured
# 0f2d9f6
# raw
# history blame
# 2.39 kB
import os
import cv2
import sys
import torch
import numpy as np
import os.path as osp
from PIL import Image
from basicsr.utils import img2tensor
from basicsr.archs.srresnet_arch import MSRResNet
# Absolute path of the directory two levels above this file; the rest of the
# module builds checkpoint and data paths relative to it.
ROOT_DIR = osp.dirname(osp.dirname(osp.abspath(__file__)))

# The same location, reached by walking up two ".." components from this file.
# It is appended to the import search path (presumably so project-local
# packages can be imported when this file is run directly -- confirm).
root_path = osp.abspath(osp.join(__file__, *([osp.pardir] * 2)))
sys.path.append(root_path)
class SRResNet:
    """Callable super-resolution wrapper around BasicSR's ``MSRResNet``.

    On construction the network is built, moved to CUDA when available
    (otherwise CPU), and its weights are loaded from
    ``<ROOT_DIR>/SR_Inference/srresnet/weights/SRResNet_<upscale>x.pth``.
    Calling the instance with an image runs one forward pass and returns the
    result as a ``uint8`` array.
    """

    def __init__(
        self, upscale=2, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16
    ):
        """Build the network and load the checkpoint for the given scale.

        Args:
            upscale: Upscaling factor; coerced to ``int`` and also used to
                select the checkpoint file name.
            num_in_ch: Forwarded to ``MSRResNet``.
            num_out_ch: Forwarded to ``MSRResNet``.
            num_feat: Forwarded to ``MSRResNet``.
            num_block: Forwarded to ``MSRResNet``.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise when the
            checkpoint is missing or does not match the architecture.
        """
        self.upscale = int(upscale)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ------------------ load model for img enhancement -------------------
        self.sr_model = MSRResNet(
            upscale=self.upscale,
            num_in_ch=num_in_ch,
            num_out_ch=num_out_ch,
            num_feat=num_feat,
            num_block=num_block,
        ).to(self.device)

        ckpt_path = os.path.join(
            ROOT_DIR,
            "SR_Inference",
            "srresnet",
            "weights",
            f"SRResNet_{self.upscale}x.pth",
        )
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source; consider
        # passing weights_only=True if the minimum supported torch allows it.
        loadnet = torch.load(ckpt_path, map_location=self.device)

        # Prefer the EMA weights when the checkpoint carries them.
        keyname = "params_ema" if "params_ema" in loadnet else "params"
        self.sr_model.load_state_dict(loadnet[keyname])
        self.sr_model.eval()

    @staticmethod
    def _min_max_to_uint8(arr):
        """Min-max rescale ``arr`` to the 0..255 range and cast to ``uint8``.

        A constant (or all-NaN) input has no usable range; dividing by the
        zero span would produce NaNs and an undefined integer cast, so an
        all-zero array of the same shape is returned instead.

        The result is always C-contiguous, whatever the memory layout of the
        input view.
        """
        lowest = arr.min()
        span = arr.max() - lowest
        if span > 0:
            scaled = (arr - lowest) / span
        else:
            scaled = np.zeros_like(arr)
        return np.ascontiguousarray((scaled * 255).astype(np.uint8))

    @torch.no_grad()
    def __call__(self, img):
        """Super-resolve a single image.

        Args:
            img: Image array; it is divided by 255 and its channels are
                swapped BGR->RGB before inference (presumably an HxWx3 uint8
                BGR array as returned by ``cv2.imread`` -- confirm).

        Returns:
            ``uint8`` array in BGR channel order. Note that the network output
            is min-max normalised per image, not clipped to [0, 1].

        Raises:
            TypeError: If ``img`` is ``None`` (e.g. a failed ``cv2.imread``).
        """
        if img is None:
            raise TypeError(
                "img is None; cv2.imread returns None when the file cannot be read"
            )

        img_tensor = (
            img2tensor(imgs=img / 255.0, bgr2rgb=True, float32=True)
            .unsqueeze(0)  # add the batch dimension
            .to(self.device)
        )

        # Drop the batch dimension, then go from CHW to HWC for NumPy/OpenCV.
        restored_img = self.sr_model(img_tensor)[0]
        restored_img = restored_img.permute(1, 2, 0).cpu().numpy()
        restored_img = self._min_max_to_uint8(restored_img)

        # The model works in RGB; hand back BGR to match the input convention.
        return cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
if __name__ == "__main__":
    # Manual smoke test: super-resolve one sample frame at 2x and save it.
    srresnet = SRResNet(upscale=2)

    input_path = f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png"
    img = cv2.imread(input_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising, so fail here with a clear message.
        raise FileNotFoundError(f"Could not read input image: {input_path}")

    sr_img = srresnet(img=img)

    saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
    os.makedirs(saving_dir, exist_ok=True)

    output_path = f"{saving_dir}/sr_img_srresnet.png"
    # cv2.imwrite returns False on failure rather than raising.
    if not cv2.imwrite(output_path, sr_img):
        raise OSError(f"Could not write output image: {output_path}")