shellyriver commited on
Commit
0650a36
1 Parent(s): 67433bc

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +41 -0
model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ class CNN(nn.Module):
6
+ def __init__(self, num_channel=1, num_classes=10, num_pixel=28):
7
+ super().__init__()
8
+ self.conv1 = nn.Conv2d(
9
+ num_channel, 32, kernel_size=5, padding=0, stride=1, bias=True
10
+ )
11
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True)
12
+ self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))
13
+ self.act = nn.ReLU(inplace=True)
14
+
15
+ ###
16
+ ### X_out = floor{ 1 + (X_in + 2*padding - dilation*(kernel_size-1) - 1)/stride }
17
+ ###
18
+ X = num_pixel
19
+ X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)
20
+ X = X / 2
21
+ X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)
22
+ X = X / 2
23
+ X = int(X)
24
+
25
+ self.fc1 = nn.Linear(64 * X * X, 512)
26
+ self.fc2 = nn.Linear(512, num_classes)
27
+
28
+ def forward(self, x):
29
+ x = self.act(self.conv1(x))
30
+ x = self.maxpool(x)
31
+ x = self.act(self.conv2(x))
32
+ x = self.maxpool(x)
33
+ x = torch.flatten(x, 1)
34
+ x = self.act(self.fc1(x))
35
+ x = self.fc2(x)
36
+ return x
37
+
38
+ def get_model():
39
+ return CNN
40
+
41
+