Chakshu123 commited on
Commit
d97c34e
1 Parent(s): 7730ebf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -2
app.py CHANGED
@@ -4,7 +4,7 @@ sys.path.insert(0, 'gradio-modified')
4
 
5
  import gradio as gr
6
  import numpy as np
7
-
8
  from PIL import Image
9
 
10
  import torch
@@ -27,7 +27,123 @@ print('Use device:', device)
27
 
28
  net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')
29
 
30
- model_net = torch.load(f'weights/colorizer.pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  def resize_original(img: Image.Image):
 
4
 
5
  import gradio as gr
6
  import numpy as np
7
+ import torch.nn as nn
8
  from PIL import Image
9
 
10
  import torch
 
27
 
28
  net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')
29
 
30
+ class BaseColor(nn.Module):
31
+ def __init__(self):
32
+ super(BaseColor, self).__init__()
33
+
34
+ self.l_cent = 50.
35
+ self.l_norm = 100.
36
+ self.ab_norm = 110.
37
+
38
+ def normalize_l(self, in_l):
39
+ return (in_l-self.l_cent)/self.l_norm
40
+
41
+ def unnormalize_l(self, in_l):
42
+ return in_l*self.l_norm + self.l_cent
43
+
44
+ def normalize_ab(self, in_ab):
45
+ return in_ab/self.ab_norm
46
+
47
+ def unnormalize_ab(self, in_ab):
48
+ return in_ab*self.ab_norm
49
+
50
+
51
+
52
+ class ECCVGenerator(BaseColor):
53
+ def __init__(self, norm_layer=nn.BatchNorm2d):
54
+ super(ECCVGenerator, self).__init__()
55
+
56
+ model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
57
+ model1+=[nn.ReLU(True),]
58
+ model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
59
+ model1+=[nn.ReLU(True),]
60
+ model1+=[norm_layer(64),]
61
+
62
+ model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
63
+ model2+=[nn.ReLU(True),]
64
+ model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
65
+ model2+=[nn.ReLU(True),]
66
+ model2+=[norm_layer(128),]
67
+
68
+ model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
69
+ model3+=[nn.ReLU(True),]
70
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
71
+ model3+=[nn.ReLU(True),]
72
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
73
+ model3+=[nn.ReLU(True),]
74
+ model3+=[norm_layer(256),]
75
+
76
+ model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
77
+ model4+=[nn.ReLU(True),]
78
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
79
+ model4+=[nn.ReLU(True),]
80
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
81
+ model4+=[nn.ReLU(True),]
82
+ model4+=[norm_layer(512),]
83
+
84
+ model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
85
+ model5+=[nn.ReLU(True),]
86
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
87
+ model5+=[nn.ReLU(True),]
88
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
89
+ model5+=[nn.ReLU(True),]
90
+ model5+=[norm_layer(512),]
91
+
92
+ model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
93
+ model6+=[nn.ReLU(True),]
94
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
95
+ model6+=[nn.ReLU(True),]
96
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
97
+ model6+=[nn.ReLU(True),]
98
+ model6+=[norm_layer(512),]
99
+
100
+ model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
101
+ model7+=[nn.ReLU(True),]
102
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
103
+ model7+=[nn.ReLU(True),]
104
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
105
+ model7+=[nn.ReLU(True),]
106
+ model7+=[norm_layer(512),]
107
+
108
+ model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
109
+ model8+=[nn.ReLU(True),]
110
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
111
+ model8+=[nn.ReLU(True),]
112
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
113
+ model8+=[nn.ReLU(True),]
114
+
115
+ model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
116
+
117
+ self.model1 = nn.Sequential(*model1)
118
+ self.model2 = nn.Sequential(*model2)
119
+ self.model3 = nn.Sequential(*model3)
120
+ self.model4 = nn.Sequential(*model4)
121
+ self.model5 = nn.Sequential(*model5)
122
+ self.model6 = nn.Sequential(*model6)
123
+ self.model7 = nn.Sequential(*model7)
124
+ self.model8 = nn.Sequential(*model8)
125
+
126
+ self.softmax = nn.Softmax(dim=1)
127
+ self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
128
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
129
+
130
+ def forward(self, input_l):
131
+ conv1_2 = self.model1(self.normalize_l(input_l))
132
+ conv2_2 = self.model2(conv1_2)
133
+ conv3_3 = self.model3(conv2_2)
134
+ conv4_3 = self.model4(conv3_3)
135
+ conv5_3 = self.model5(conv4_3)
136
+ conv6_3 = self.model6(conv5_3)
137
+ conv7_3 = self.model7(conv6_3)
138
+ conv8_3 = self.model8(conv7_3)
139
+ out_reg = self.model_out(self.softmax(conv8_3))
140
+
141
+ return self.unnormalize_ab(self.upsample4(out_reg))
142
+
143
+
144
+ # model_net = torch.load(f'weights/colorizer.pt')
145
+ model = ECCVGenerator()
146
+ model_net.load_state_dict(torch.load(f'weights/colorizer.pt'))
147
 
148
 
149
  def resize_original(img: Image.Image):