jamino30 commited on
Commit
3f52b3e
1 Parent(s): 71cc23e

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. vgg16.py +25 -0
  3. vgg19.py +6 -1
app.py CHANGED
@@ -9,14 +9,14 @@ import torch.optim as optim
9
  import gradio as gr
10
 
11
  from utils import preprocess_img, preprocess_img_from_path, postprocess_img
12
- from vgg19 import VGG_19
13
 
14
  if torch.cuda.is_available(): device = 'cuda'
15
  elif torch.backends.mps.is_available(): device = 'mps'
16
  else: device = 'cpu'
17
  print('DEVICE:', device)
18
 
19
- model = VGG_19().to(device)
20
  for param in model.parameters():
21
  param.requires_grad = False
22
 
 
9
  import gradio as gr
10
 
11
  from utils import preprocess_img, preprocess_img_from_path, postprocess_img
12
+ from vgg16 import VGG_16
13
 
14
  if torch.cuda.is_available(): device = 'cuda'
15
  elif torch.backends.mps.is_available(): device = 'mps'
16
  else: device = 'cpu'
17
  print('DEVICE:', device)
18
 
19
+ model = VGG_16().to(device)
20
  for param in model.parameters():
21
  param.requires_grad = False
22
 
vgg16.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torchvision.models as models
3
+
4
+ class VGG_16(nn.Module):
5
+ def __init__(self):
6
+ super(VGG_16, self).__init__()
7
+ self.model = models.vgg16(weights='DEFAULT').features[:30]
8
+
9
+ for i, _ in enumerate(self.model):
10
+ if i in [4, 9, 16, 23]:
11
+ self.model[i] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
12
+
13
+ def forward(self, x):
14
+ features = []
15
+
16
+ for i, layer in enumerate(self.model):
17
+ x = layer(x)
18
+ if i in [0, 5, 10, 17, 24]:
19
+ features.append(x)
20
+ return features
21
+
22
+
23
+ if __name__ == '__main__':
24
+ model = VGG_16()
25
+ print(model)
vgg19.py CHANGED
@@ -17,4 +17,9 @@ class VGG_19(nn.Module):
17
  x = layer(x)
18
  if i in [0, 5, 10, 19, 28]:
19
  features.append(x)
20
- return features
 
 
 
 
 
 
17
  x = layer(x)
18
  if i in [0, 5, 10, 19, 28]:
19
  features.append(x)
20
+ return features
21
+
22
+
23
+ if __name__ == '__main__':
24
+ model = VGG_19()
25
+ print(model)