Update app.py
Browse files
app.py
CHANGED
@@ -114,8 +114,8 @@ def inference(image, upscale, large_input_flag, color_fix):
|
|
114 |
|
115 |
# img2tensor
|
116 |
y = np.array(image).astype(np.float32) / 255.
|
117 |
-
|
118 |
-
y = torch.from_numpy(np.transpose(y, (2, 0, 1))).float()
|
119 |
y = y.unsqueeze(0).to(device)
|
120 |
|
121 |
# inference
|
@@ -148,8 +148,8 @@ def inference(image, upscale, large_input_flag, color_fix):
|
|
148 |
output = wavelet_reconstruction(output, y)
|
149 |
# tensor2img
|
150 |
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
151 |
-
|
152 |
-
|
153 |
output = (output * 255.0).round().astype(np.uint8)
|
154 |
|
155 |
# # save results
|
|
|
114 |
|
115 |
# img2tensor
|
116 |
y = np.array(image).astype(np.float32) / 255.
|
117 |
+
y = torch.from_numpy(np.transpose(y[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
118 |
+
# y = torch.from_numpy(np.transpose(y, (2, 0, 1))).float()
|
119 |
y = y.unsqueeze(0).to(device)
|
120 |
|
121 |
# inference
|
|
|
148 |
output = wavelet_reconstruction(output, y)
|
149 |
# tensor2img
|
150 |
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
151 |
+
if output.ndim == 3:
|
152 |
+
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
|
153 |
output = (output * 255.0).round().astype(np.uint8)
|
154 |
|
155 |
# # save results
|