Update app.py
Browse files
app.py
CHANGED
@@ -116,23 +116,24 @@ def load_img(filename, norm=True):
|
|
116 |
|
117 |
|
118 |
def inference(image, upscale, large_input_flag, color_fix):
|
119 |
-
|
120 |
-
upscale = 2
|
121 |
-
|
122 |
-
upscale = int(upscale)
|
123 |
-
if 0 < upscale < 3:
|
124 |
upscale = 2
|
|
|
|
|
125 |
|
126 |
model = set_safmn(upscale)
|
127 |
|
128 |
-
img =
|
|
|
|
|
|
|
129 |
img = img.astype(np.float32) / 255.
|
130 |
-
|
131 |
-
|
132 |
|
133 |
# inference
|
134 |
if large_input_flag:
|
135 |
-
patches, idx, size = img2patch(
|
136 |
with torch.no_grad():
|
137 |
n = len(patches)
|
138 |
outs = []
|
@@ -152,19 +153,24 @@ def inference(image, upscale, large_input_flag, color_fix):
|
|
152 |
output = patch2img(output, idx, size, scale=upscale)
|
153 |
else:
|
154 |
with torch.no_grad():
|
155 |
-
output = model(
|
156 |
|
157 |
# color fix
|
158 |
if color_fix:
|
159 |
-
|
160 |
-
output = wavelet_reconstruction(output,
|
161 |
# tensor2img
|
162 |
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
163 |
if output.ndim == 3:
|
164 |
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
165 |
output = (output * 255.0).round().astype(np.uint8)
|
166 |
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
|
170 |
|
|
|
116 |
|
117 |
|
118 |
def inference(image, upscale, large_input_flag, color_fix):
|
119 |
+
if upscale is None or not isinstance(upscale, (int, float)) or upscale == 3:
|
|
|
|
|
|
|
|
|
120 |
upscale = 2
|
121 |
+
|
122 |
+
upscale = int(upscale) # convert type to int
|
123 |
|
124 |
model = set_safmn(upscale)
|
125 |
|
126 |
+
img = cv2.imread(str(image), cv2.IMREAD_COLOR)
|
127 |
+
print(f'input size: {img.shape}')
|
128 |
+
|
129 |
+
# img2tensor
|
130 |
img = img.astype(np.float32) / 255.
|
131 |
+
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
132 |
+
img = img.unsqueeze(0).to(device)
|
133 |
|
134 |
# inference
|
135 |
if large_input_flag:
|
136 |
+
patches, idx, size = img2patch(img, scale=upscale)
|
137 |
with torch.no_grad():
|
138 |
n = len(patches)
|
139 |
outs = []
|
|
|
153 |
output = patch2img(output, idx, size, scale=upscale)
|
154 |
else:
|
155 |
with torch.no_grad():
|
156 |
+
output = model(img)
|
157 |
|
158 |
# color fix
|
159 |
if color_fix:
|
160 |
+
img = F.interpolate(img, scale_factor=upscale, mode='bilinear')
|
161 |
+
output = wavelet_reconstruction(output, img)
|
162 |
# tensor2img
|
163 |
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
164 |
if output.ndim == 3:
|
165 |
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
166 |
output = (output * 255.0).round().astype(np.uint8)
|
167 |
|
168 |
+
# save restored img
|
169 |
+
save_path = f'results/out.png'
|
170 |
+
cv2.imwrite(save_path, output)
|
171 |
+
|
172 |
+
output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
|
173 |
+
return output, save_path
|
174 |
|
175 |
|
176 |
|