File size: 3,593 Bytes
0771f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f76e2
da84e69
0771f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import io
import numpy as np
import onnxruntime
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn as nn
import torch.nn.init as init
import matplotlib.pyplot as plt
import json
from PIL import Image, ImageDraw, ImageFont
from resizeimage import resizeimage
import numpy as np
import pdb
import onnx
import gradio as gr
import os

class SuperResolutionNet(nn.Module):
    """Sub-pixel CNN that enlarges a single-channel image by ``upscale_factor``.

    Three ReLU-activated convolutions extract features, a fourth convolution
    emits ``upscale_factor ** 2`` channels, and ``nn.PixelShuffle`` rearranges
    those channels into an image ``upscale_factor`` times larger along each
    spatial axis.

    Args:
        upscale_factor: Integer spatial enlargement factor.
        inplace: Forwarded to ``nn.ReLU``.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()

        self.relu = nn.ReLU(inplace=inplace)
        # Attribute names must stay conv1..conv4: they form the state_dict keys
        # that a pretrained checkpoint is matched against.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Apply the three ReLU conv stages, then conv4 followed by pixel shuffle."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise conv weights (ReLU gain on the hidden layers)."""
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)

# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # just a random number

# Initialize model with the pretrained weights. Without CUDA, map the stored
# tensors onto the CPU; with CUDA, restore them to the device they were saved on.
map_location = None if torch.cuda.is_available() else "cpu"
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# Dummy 1x1x224x224 input. Nothing below reads `x`; it appears to be left over
# from an export example.
x = torch.randn(1, 1, 224, 224, requires_grad=True)
torch_model.eval()

import urllib.request

ONNX_MODEL_URL = (
    "https://github.com/AK391/models/raw/main/vision/super_resolution/"
    "sub_pixel_cnn_2016/model/super-resolution-10.onnx"
)
ONNX_MODEL_PATH = "super-resolution-10.onnx"

# Download the ONNX model once. urlretrieve raises on HTTP/network errors,
# unlike `os.system("wget ...")`, whose non-zero exit status was silently
# ignored. Skipping an existing file avoids wget's duplicate ".1" copies.
if not os.path.exists(ONNX_MODEL_PATH):
    urllib.request.urlretrieve(ONNX_MODEL_URL, ONNX_MODEL_PATH)

# Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
# other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
# based on the build flags) when instantiating InferenceSession.
# For example, if NVIDIA GPU is available and ORT Python package is built with CUDA, then call API as following:
# onnxruntime.InferenceSession(path/to/model, providers=['CUDAExecutionProvider'])
# The provider is named explicitly so session creation does not depend on how
# the installed ORT package was built.
# Per-image inference happens inside `inference()`. It cannot run here because
# no input image exists at import time.
ort_session = onnxruntime.InferenceSession(
    ONNX_MODEL_PATH, providers=["CPUExecutionProvider"]
)

def inference(img):
    """Super-resolve an uploaded image with the ONNX model in ``ort_session``.

    The upload is converted to RGB and resized to cover 224x224 via
    ``resizeimage.resize_cover``. It is then split into YCbCr, and only the
    luma (Y) channel is run through the network. The chroma channels (Cb, Cr)
    are resized bicubically to the network's output size and merged back.

    Args:
        img: Path to the input image file. The Gradio input below uses
            ``type="filepath"``.

    Returns:
        The super-resolved image as an RGB ``PIL.Image.Image``.
    """
    # Close the file handle once the pixels have been copied out.
    with Image.open(img) as orig_img:
        # Normalise palette / alpha / grayscale uploads so the YCbCr split
        # below does not depend on the uploaded file's mode.
        rgb_img = orig_img.convert("RGB")
    resized = resizeimage.resize_cover(rgb_img, [224, 224], validate=False)
    img_y_0, img_cb, img_cr = resized.convert("YCbCr").split()

    # Model input: float32 luma scaled to [0, 1], shaped (1, 1, H, W).
    img_ndarray = np.asarray(img_y_0)
    img_4 = np.expand_dims(np.expand_dims(img_ndarray, axis=0), axis=0)
    img_5 = img_4.astype(np.float32) / 255.0

    # Run the network on THIS image (previously this ran once at import time
    # on an undefined variable, and the result was never used here).
    ort_inputs = {ort_session.get_inputs()[0].name: img_5}
    ort_outs = ort_session.run(None, ort_inputs)
    img_out = ort_outs[0]

    # Output presumably (1, 1, H', W'): drop the two leading axes and map back to 8-bit.
    img_out_y = Image.fromarray(np.uint8((img_out[0] * 255.0).clip(0, 255)[0]), mode='L')
    final_img = Image.merge(
        "YCbCr",
        [
            img_out_y,
            img_cb.resize(img_out_y.size, Image.BICUBIC),
            img_cr.resize(img_out_y.size, Image.BICUBIC),
        ],
    ).convert("RGB")
    return final_img

# Build the web UI: the upload is handed to `inference` as a file path, and the
# returned PIL image is displayed.
demo = gr.Interface(
    fn=inference,
    inputs=gr.inputs.Image(type="filepath"),
    outputs=gr.outputs.Image(type="pil"),
)
demo.launch()