Meloo commited on
Commit
d7840c4
·
verified ·
1 Parent(s): 20d7a63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -104,7 +104,7 @@ def patch2img(outs, idxes, sr_size, scale=4, crop_size=512):
104
  return (preds / count_mt).to(outs.device)
105
 
106
 
107
- def load_img (filename, norm=True):
108
  img = np.array(Image.open(filename).convert("RGB"))
109
  h, w = img.shape[:2]
110
 
@@ -114,20 +114,21 @@ def load_img (filename, norm=True):
114
  return img
115
 
116
 
117
- def inference(image, upscale, large_input_flag, color_fix):
118
- model = set_safmn(upscale)
119
-
120
- img = np.array(image).astype(np.float32) / 255.
121
-
122
- y = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
123
- y = y.unsqueeze(0).to(device)
124
 
 
125
  upscale = int(upscale) # convert type to int
126
  if upscale > 4:
127
  upscale = 4
128
  if 0 < upscale < 3:
129
  upscale = 2
130
 
 
 
 
 
 
 
 
131
  # inference
132
  if large_input_flag:
133
  patches, idx, size = img2patch(y, scale=upscale)
@@ -154,7 +155,7 @@ def inference(image, upscale, large_input_flag, color_fix):
154
 
155
  # color fix
156
  if color_fix:
157
- y = F.interpolate(y, scale_factor=upscale, mode='bilinear')
158
  output = wavelet_reconstruction(output, y)
159
  # tensor2img
160
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
@@ -162,7 +163,8 @@ def inference(image, upscale, large_input_flag, color_fix):
162
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
163
  output = (output * 255.0).round().astype(np.uint8)
164
 
165
- return (image, Image.fromarray(output))
 
166
 
167
 
168
 
 
104
  return (preds / count_mt).to(outs.device)
105
 
106
 
107
+ def load_img(filename, norm=True):
108
  img = np.array(Image.open(filename).convert("RGB"))
109
  h, w = img.shape[:2]
110
 
 
114
  return img
115
 
116
 
 
 
 
 
 
 
 
117
 
118
+ def inference(image, upscale, large_input_flag, color_fix):
119
  upscale = int(upscale) # convert type to int
120
  if upscale > 4:
121
  upscale = 4
122
  if 0 < upscale < 3:
123
  upscale = 2
124
 
125
+ model = set_safmn(upscale)
126
+
127
+ img = np.array(image)
128
+ img = img.astype(np.float32) / 255.
129
+ y = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
130
+ y = y.unsqueeze(0).to(device)
131
+
132
  # inference
133
  if large_input_flag:
134
  patches, idx, size = img2patch(y, scale=upscale)
 
155
 
156
  # color fix
157
  if color_fix:
158
+ y = F.interpolate(img, scale_factor=upscale, mode='bilinear')
159
  output = wavelet_reconstruction(output, y)
160
  # tensor2img
161
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
 
163
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
164
  output = (output * 255.0).round().astype(np.uint8)
165
 
166
+ return (image, output)
167
+
168
 
169
 
170