Create app.py
app.py
ADDED
@@ -0,0 +1,92 @@
import os

import numpy as np
import onnxruntime
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.onnx
import torch.utils.model_zoo as model_zoo
import gradio as gr
from PIL import Image
from resizeimage import resizeimage

# Sub-pixel CNN super-resolution network: upscales a single-channel (Y) image by `upscale_factor`.
class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)

# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1  # just a random number

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# Dummy input and eval mode (carried over from the ONNX export workflow; `x` is not used further in this app).
x = torch.randn(1, 1, 224, 224, requires_grad=True)
torch_model.eval()

# Download the pretrained super-resolution ONNX model.
os.system("wget https://github.com/AK391/models/raw/main/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx")

# Starting from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution
# providers other than the default CPU provider (as opposed to the previous behavior of providers getting
# set/registered by default based on the build flags) when instantiating InferenceSession.
# For example, if an NVIDIA GPU is available and the ORT Python package is built with CUDA, call the API as follows:
# onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])
ort_session = onnxruntime.InferenceSession("super-resolution-10.onnx")

def inference(img):
    # Preprocess: resize to 224x224 and extract the Y (luma) channel.
    orig_img = Image.open(img)
    img = resizeimage.resize_cover(orig_img, [224, 224], validate=False)
    img_ycbcr = img.convert('YCbCr')
    img_y_0, img_cb, img_cr = img_ycbcr.split()
    img_ndarray = np.asarray(img_y_0)

    img_4 = np.expand_dims(np.expand_dims(img_ndarray, axis=0), axis=0)
    img_5 = img_4.astype(np.float32) / 255.0

    # Run the ONNX super-resolution model on the normalized Y channel.
    ort_inputs = {ort_session.get_inputs()[0].name: img_5}
    ort_outs = ort_session.run(None, ort_inputs)
    img_out_y = ort_outs[0]

    # Postprocess: rebuild an RGB image from the upscaled Y channel and bicubically resized Cb/Cr channels.
    img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')
    final_img = Image.merge(
        "YCbCr", [
            img_out_y,
            img_cb.resize(img_out_y.size, Image.BICUBIC),
            img_cr.resize(img_out_y.size, Image.BICUBIC),
        ]).convert("RGB")
    return final_img

gr.Interface(inference, gr.inputs.Image(type="filepath"), gr.outputs.Image(type="pil")).launch()
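
For the Space to build, a requirements.txt along these lines would presumably also be needed (package names inferred from the imports in app.py above; leaving them unpinned is an assumption):

gradio
numpy
onnxruntime
torch
Pillow
python-resize-image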