akhaliq HF staff commited on
Commit
0771f1b
1 Parent(s): fd8772c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import numpy as np
3
+ import onnxruntime
4
+ from torch import nn
5
+ import torch.utils.model_zoo as model_zoo
6
+ import torch.onnx
7
+ import torch.nn as nn
8
+ import torch.nn.init as init
9
+ import matplotlib.pyplot as plt
10
+ import json
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ from resizeimage import resizeimage
13
+ import numpy as np
14
+ import pdb
15
+ import onnx
16
+
17
+ class SuperResolutionNet(nn.Module):
18
+ def __init__(self, upscale_factor, inplace=False):
19
+ super(SuperResolutionNet, self).__init__()
20
+
21
+ self.relu = nn.ReLU(inplace=inplace)
22
+ self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
23
+ self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
24
+ self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
25
+ self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
26
+ self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
27
+
28
+ self._initialize_weights()
29
+
30
+ def forward(self, x):
31
+ x = self.relu(self.conv1(x))
32
+ x = self.relu(self.conv2(x))
33
+ x = self.relu(self.conv3(x))
34
+ x = self.pixel_shuffle(self.conv4(x))
35
+ return x
36
+
37
+ def _initialize_weights(self):
38
+ init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
39
+ init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
40
+ init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
41
+ init.orthogonal_(self.conv4.weight)
42
+
43
+ # Create the super-resolution model by using the above model definition.
44
+ torch_model = SuperResolutionNet(upscale_factor=3)
45
+
46
+ model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
47
+ batch_size = 1 # just a random number
48
+
49
+ # Initialize model with the pretrained weights
50
+ map_location = lambda storage, loc: storage
51
+ if torch.cuda.is_available():
52
+ map_location = None
53
+ torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
54
+
55
+
56
+
57
+ x = torch.randn(1, 1, 224, 224, requires_grad=True)
58
+ torch_model.eval()
59
+
60
+
61
+
62
+ os.system("wget https://github.com/AK391/models/raw/main/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx")
63
+
64
+ # Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
65
+ # other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
66
+ # based on the build flags) when instantiating InferenceSession.
67
+ # For example, if NVIDIA GPU is available and ORT Python package is built with CUDA, then call API as following:
68
+ # onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])
69
+ ort_session = onnxruntime.InferenceSession("super-resolution-10.onnx")
70
+ ort_inputs = {ort_session.get_inputs()[0].name: img_5}
71
+ ort_outs = ort_session.run(None, ort_inputs)
72
+ img_out_y = ort_outs[0]
73
+
74
+ def inference(img):
75
+ orig_img = Image.open(img)
76
+ img = resizeimage.resize_cover(orig_img, [224,224], validate=False)
77
+ img_ycbcr = img.convert('YCbCr')
78
+ img_y_0, img_cb, img_cr = img_ycbcr.split()
79
+ img_ndarray = np.asarray(img_y_0)
80
+
81
+ img_4 = np.expand_dims(np.expand_dims(img_ndarray, axis=0), axis=0)
82
+ img_5 = img_4.astype(np.float32) / 255.0
83
+ img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')
84
+ final_img = Image.merge(
85
+ "YCbCr", [
86
+ img_out_y,
87
+ img_cb.resize(img_out_y.size, Image.BICUBIC),
88
+ img_cr.resize(img_out_y.size, Image.BICUBIC),
89
+ ]).convert("RGB")
90
+ return final_image
91
+
92
+ gr.Interface(inference,gr.inputs.Image(type="filepath"),gr.outputs.Image(type="pil")).launch()