Meloo commited on
Commit
6265765
1 Parent(s): 815a758

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -116,23 +116,24 @@ def load_img(filename, norm=True):
116
 
117
 
118
  def inference(image, upscale, large_input_flag, color_fix):
119
- if upscale is None or not isinstance(upscale, (int, float)):
120
- upscale = 2
121
-
122
- upscale = int(upscale)
123
- if 0 < upscale < 3:
124
  upscale = 2
 
 
125
 
126
  model = set_safmn(upscale)
127
 
128
- img = np.array(image)
 
 
 
129
  img = img.astype(np.float32) / 255.
130
- y = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
131
- y = y.unsqueeze(0).to(device)
132
 
133
  # inference
134
  if large_input_flag:
135
- patches, idx, size = img2patch(y, scale=upscale)
136
  with torch.no_grad():
137
  n = len(patches)
138
  outs = []
@@ -152,19 +153,24 @@ def inference(image, upscale, large_input_flag, color_fix):
152
  output = patch2img(output, idx, size, scale=upscale)
153
  else:
154
  with torch.no_grad():
155
- output = model(y)
156
 
157
  # color fix
158
  if color_fix:
159
- y = F.interpolate(img, scale_factor=upscale, mode='bilinear')
160
- output = wavelet_reconstruction(output, y)
161
  # tensor2img
162
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
163
  if output.ndim == 3:
164
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
165
  output = (output * 255.0).round().astype(np.uint8)
166
 
167
- return (image, output)
 
 
 
 
 
168
 
169
 
170
 
 
116
 
117
 
118
  def inference(image, upscale, large_input_flag, color_fix):
119
+ if upscale is None or not isinstance(upscale, (int, float)) or upscale == 3:
 
 
 
 
120
  upscale = 2
121
+
122
+ upscale = int(upscale) # convert type to int
123
 
124
  model = set_safmn(upscale)
125
 
126
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
127
+ print(f'input size: {img.shape}')
128
+
129
+ # img2tensor
130
  img = img.astype(np.float32) / 255.
131
+ img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
132
+ img = img.unsqueeze(0).to(device)
133
 
134
  # inference
135
  if large_input_flag:
136
+ patches, idx, size = img2patch(img, scale=upscale)
137
  with torch.no_grad():
138
  n = len(patches)
139
  outs = []
 
153
  output = patch2img(output, idx, size, scale=upscale)
154
  else:
155
  with torch.no_grad():
156
+ output = model(img)
157
 
158
  # color fix
159
  if color_fix:
160
+ img = F.interpolate(img, scale_factor=upscale, mode='bilinear')
161
+ output = wavelet_reconstruction(output, img)
162
  # tensor2img
163
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
164
  if output.ndim == 3:
165
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
166
  output = (output * 255.0).round().astype(np.uint8)
167
 
168
+ # save restored img
169
+ save_path = f'results/out.png'
170
+ cv2.imwrite(save_path, output)
171
+
172
+ output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
173
+ return output, save_path
174
 
175
 
176