Meloo commited on
Commit
1b5a239
·
verified ·
1 Parent(s): ba0771c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -33
app.py CHANGED
@@ -106,7 +106,15 @@ def patch2img(outs, idxes, sr_size, scale=4, crop_size=512):
106
  return (preds / count_mt).to(outs.device)
107
 
108
 
109
- os.makedirs('./results', exist_ok=True)
 
 
 
 
 
 
 
 
110
 
111
  def inference(image, upscale, large_input_flag, color_fix):
112
  upscale = int(upscale) # convert type to int
@@ -117,17 +125,16 @@ def inference(image, upscale, large_input_flag, color_fix):
117
 
118
  model = set_safmn(upscale)
119
 
120
- img = cv2.imread(str(image), cv2.IMREAD_COLOR)
121
- print(f'input size: {img.shape}')
 
122
 
123
- # img2tensor
124
- img = img.astype(np.float32) / 255.
125
- img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
126
- img = img.unsqueeze(0).to(device)
127
 
128
  # inference
129
  if large_input_flag:
130
- patches, idx, size = img2patch(img, scale=upscale)
131
  with torch.no_grad():
132
  n = len(patches)
133
  outs = []
@@ -147,24 +154,19 @@ def inference(image, upscale, large_input_flag, color_fix):
147
  output = patch2img(output, idx, size, scale=upscale)
148
  else:
149
  with torch.no_grad():
150
- output = model(img)
151
 
152
  # color fix
153
  if color_fix:
154
- img = F.interpolate(img, scale_factor=upscale, mode='bilinear')
155
- output = wavelet_reconstruction(output, img)
156
  # tensor2img
157
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
158
  if output.ndim == 3:
159
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
160
  output = (output * 255.0).round().astype(np.uint8)
161
 
162
- # save restored img
163
- save_path = f'results/out.png'
164
- cv2.imwrite(save_path, output)
165
-
166
- output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
167
- return output, save_path
168
 
169
 
170
 
@@ -194,21 +196,16 @@ article = "<p style='text-align: center'><a href='https://eduardzamfir.github.io
194
 
195
  #### Image,Prompts examples
196
  examples = [
197
- ['images/0801x4.png'],
198
- ['images/0840x4.png'],
199
- ['images/0841x4.png'],
200
- ['images/0870x4.png'],
201
- ['images/0878x4.png'],
202
- ['images/0884x4.png'],
203
- ['images/0900x4.png'],
204
- ['images/img002x4.png'],
205
- ['images/img003x4.png'],
206
- ['images/img004x4.png'],
207
- ['images/img035x4.png'],
208
- ['images/img053x4.png'],
209
- ['images/img064x4.png'],
210
- ['images/img083x4.png'],
211
- ['images/img092x4.png'],
212
  ]
213
 
214
  css = """
@@ -220,7 +217,7 @@ css = """
220
  """
221
 
222
  demo = gr.Interface(
223
- fn=process_img,
224
  inputs=[
225
  gr.Image(type="pil", label="Input", value="real_testdata/004.png"),
226
  gr.Number(default=2, label="Upscaling factor (up to 4)"),
 
106
  return (preds / count_mt).to(outs.device)
107
 
108
 
109
+ def load_img (filename, norm=True,):
110
+ img = np.array(Image.open(filename).convert("RGB"))
111
+ h, w = img.shape[:2]
112
+
113
+ if norm:
114
+ img = img.astype(np.float32) / 255.
115
+
116
+ return img
117
+
118
 
119
  def inference(image, upscale, large_input_flag, color_fix):
120
  upscale = int(upscale) # convert type to int
 
125
 
126
  model = set_safmn(upscale)
127
 
128
+ img = np.array(image)
129
+ img = img / 255.
130
+ img = img.astype(np.float32)
131
 
132
+ # img2tensor
133
+ y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
 
 
134
 
135
  # inference
136
  if large_input_flag:
137
+ patches, idx, size = img2patch(y, scale=upscale)
138
  with torch.no_grad():
139
  n = len(patches)
140
  outs = []
 
154
  output = patch2img(output, idx, size, scale=upscale)
155
  else:
156
  with torch.no_grad():
157
+ output = model(y)
158
 
159
  # color fix
160
  if color_fix:
161
+ y = F.interpolate(y, scale_factor=upscale, mode='bilinear')
162
+ output = wavelet_reconstruction(output, y)
163
  # tensor2img
164
  output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
165
  if output.ndim == 3:
166
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
167
  output = (output * 255.0).round().astype(np.uint8)
168
 
169
+ return (image, Image.fromarray(output))
 
 
 
 
 
170
 
171
 
172
 
 
196
 
197
  #### Image,Prompts examples
198
  examples = [
199
+ ['real_testdata/004.png'],
200
+ ['real_testdata/005.png'],
201
+ ['real_testdata/010.png'],
202
+ ['real_testdata/015.png'],
203
+ ['real_testdata/025.png'],
204
+ ['real_testdata/030.png'],
205
+ ['real_testdata/034.png'],
206
+ ['real_testdata/044.png'],
207
+ ['real_testdata/041.png'],
208
+ ['real_testdata/054.png'],
 
 
 
 
 
209
  ]
210
 
211
  css = """
 
217
  """
218
 
219
  demo = gr.Interface(
220
+ fn=inference,
221
  inputs=[
222
  gr.Image(type="pil", label="Input", value="real_testdata/004.png"),
223
  gr.Number(default=2, label="Upscaling factor (up to 4)"),