Update app.py
Browse files
app.py
CHANGED
@@ -22,10 +22,10 @@ pretrain_model_url = {
|
|
22 |
|
23 |
|
24 |
# download weights
|
25 |
-
if not os.path.exists('
|
26 |
load_file_from_url(url=pretrain_model_url['safmn_x2'], model_dir='./pretrained_models/', progress=True, file_name=None)
|
27 |
|
28 |
-
if not os.path.exists('
|
29 |
load_file_from_url(url=pretrain_model_url['safmn_x4'], model_dir='./pretrained_models/', progress=True, file_name=None)
|
30 |
|
31 |
|
@@ -34,9 +34,9 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
34 |
def set_safmn(upscale):
|
35 |
model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=upscale)
|
36 |
if upscale == 2:
|
37 |
-
model_path = '
|
38 |
elif upscale == 4:
|
39 |
-
model_path = '
|
40 |
else:
|
41 |
raise NotImplementedError('Only support x2/x4 upscaling!')
|
42 |
|
@@ -104,17 +104,6 @@ def patch2img(outs, idxes, sr_size, scale=4, crop_size=512):
|
|
104 |
return (preds / count_mt).to(outs.device)
|
105 |
|
106 |
|
107 |
-
def load_img(filename, norm=True):
|
108 |
-
img = np.array(Image.open(filename).convert("RGB"))
|
109 |
-
h, w = img.shape[:2]
|
110 |
-
|
111 |
-
if norm:
|
112 |
-
img = img.astype(np.float32) / 255.
|
113 |
-
|
114 |
-
return img
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
def inference(image, upscale, large_input_flag, color_fix):
|
119 |
if upscale is None or not isinstance(upscale, (int, float)) or upscale == 3.:
|
120 |
upscale = 2
|
|
|
22 |
|
23 |
|
24 |
# download weights
|
25 |
+
if not os.path.exists('pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'):
|
26 |
load_file_from_url(url=pretrain_model_url['safmn_x2'], model_dir='./pretrained_models/', progress=True, file_name=None)
|
27 |
|
28 |
+
if not os.path.exists('pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'):
|
29 |
load_file_from_url(url=pretrain_model_url['safmn_x4'], model_dir='./pretrained_models/', progress=True, file_name=None)
|
30 |
|
31 |
|
|
|
34 |
def set_safmn(upscale):
|
35 |
model = SAFMN(dim=128, n_blocks=16, ffn_scale=2.0, upscaling_factor=upscale)
|
36 |
if upscale == 2:
|
37 |
+
model_path = 'pretrained_models/SAFMN_L_Real_LSDIR_x2-v2.pth'
|
38 |
elif upscale == 4:
|
39 |
+
model_path = 'pretrained_models/SAFMN_L_Real_LSDIR_x4-v2.pth'
|
40 |
else:
|
41 |
raise NotImplementedError('Only support x2/x4 upscaling!')
|
42 |
|
|
|
104 |
return (preds / count_mt).to(outs.device)
|
105 |
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
def inference(image, upscale, large_input_flag, color_fix):
|
108 |
if upscale is None or not isinstance(upscale, (int, float)) or upscale == 3.:
|
109 |
upscale = 2
|