amanmibra commited on
Commit
e5db3e9
1 Parent(s): 3806d0c

Create CNNetwork

Browse files
__pycache__/cnn.cpython-39.pyc ADDED
Binary file (1.37 kB). View file
 
cnn.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from torchsummary import summary
3
+
4
+
5
+ class CNNetwork(nn.Module):
6
+
7
+ def __init__(self):
8
+ super().__init__()
9
+ # 4 conv blocks / flatten / linear / softmax
10
+ self.conv1 = nn.Sequential(
11
+ nn.Conv2d(
12
+ in_channels=1,
13
+ out_channels=16,
14
+ kernel_size=3,
15
+ stride=1,
16
+ padding=2
17
+ ),
18
+ nn.ReLU(),
19
+ nn.MaxPool2d(kernel_size=2)
20
+ )
21
+ self.conv2 = nn.Sequential(
22
+ nn.Conv2d(
23
+ in_channels=16,
24
+ out_channels=32,
25
+ kernel_size=3,
26
+ stride=1,
27
+ padding=2
28
+ ),
29
+ nn.ReLU(),
30
+ nn.MaxPool2d(kernel_size=2)
31
+ )
32
+ self.conv3 = nn.Sequential(
33
+ nn.Conv2d(
34
+ in_channels=32,
35
+ out_channels=64,
36
+ kernel_size=3,
37
+ stride=1,
38
+ padding=2
39
+ ),
40
+ nn.ReLU(),
41
+ nn.MaxPool2d(kernel_size=2)
42
+ )
43
+ self.conv4 = nn.Sequential(
44
+ nn.Conv2d(
45
+ in_channels=64,
46
+ out_channels=128,
47
+ kernel_size=3,
48
+ stride=1,
49
+ padding=2
50
+ ),
51
+ nn.ReLU(),
52
+ nn.MaxPool2d(kernel_size=2)
53
+ )
54
+ self.flatten = nn.Flatten()
55
+ self.linear = nn.Linear(128 * 5 * 4, 10)
56
+ self.softmax = nn.Softmax(dim=1)
57
+
58
+ def forward(self, input_data):
59
+ x = self.conv1(input_data)
60
+ x = self.conv2(x)
61
+ x = self.conv3(x)
62
+ x = self.conv4(x)
63
+ x = self.flatten(x)
64
+ logits = self.linear(x)
65
+ predictions = self.softmax(logits)
66
+ return predictions
notebooks/playground.ipynb CHANGED
@@ -3,7 +3,7 @@
3
  {
4
  "cell_type": "code",
5
  "execution_count": 8,
6
- "id": "7f11e761",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
@@ -14,7 +14,7 @@
14
  {
15
  "cell_type": "code",
16
  "execution_count": 10,
17
- "id": "f3deb79d",
18
  "metadata": {},
19
  "outputs": [],
20
  "source": [
@@ -24,30 +24,32 @@
24
  },
25
  {
26
  "cell_type": "code",
27
- "execution_count": 76,
28
- "id": "eb9888a5",
29
  "metadata": {},
30
  "outputs": [],
31
  "source": [
32
  "import os\n",
33
  "\n",
34
- "import torch"
 
35
  ]
36
  },
37
  {
38
  "cell_type": "code",
39
- "execution_count": 77,
40
- "id": "75440e63",
41
  "metadata": {},
42
  "outputs": [],
43
  "source": [
44
- "from dataset import *"
 
45
  ]
46
  },
47
  {
48
  "cell_type": "code",
49
  "execution_count": 78,
50
- "id": "5b51f712",
51
  "metadata": {},
52
  "outputs": [
53
  {
@@ -69,7 +71,7 @@
69
  {
70
  "cell_type": "code",
71
  "execution_count": 80,
72
- "id": "253f87d6",
73
  "metadata": {},
74
  "outputs": [],
75
  "source": [
@@ -85,7 +87,7 @@
85
  {
86
  "cell_type": "code",
87
  "execution_count": 81,
88
- "id": "3d5c127a",
89
  "metadata": {},
90
  "outputs": [
91
  {
@@ -106,7 +108,7 @@
106
  {
107
  "cell_type": "code",
108
  "execution_count": 82,
109
- "id": "cbac184f",
110
  "metadata": {},
111
  "outputs": [
112
  {
@@ -134,7 +136,7 @@
134
  {
135
  "cell_type": "code",
136
  "execution_count": 83,
137
- "id": "2bd8c582",
138
  "metadata": {},
139
  "outputs": [
140
  {
@@ -152,10 +154,56 @@
152
  "dataset[0][0].shape"
153
  ]
154
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  {
156
  "cell_type": "code",
157
  "execution_count": null,
158
- "id": "c3c7b1d4",
159
  "metadata": {},
160
  "outputs": [],
161
  "source": []
 
3
  {
4
  "cell_type": "code",
5
  "execution_count": 8,
6
+ "id": "46dbbffd",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
 
14
  {
15
  "cell_type": "code",
16
  "execution_count": 10,
17
+ "id": "56056453",
18
  "metadata": {},
19
  "outputs": [],
20
  "source": [
 
24
  },
25
  {
26
  "cell_type": "code",
27
+ "execution_count": 86,
28
+ "id": "b5cbac9d",
29
  "metadata": {},
30
  "outputs": [],
31
  "source": [
32
  "import os\n",
33
  "\n",
34
+ "import torch\n",
35
+ "from torchsummary import summary"
36
  ]
37
  },
38
  {
39
  "cell_type": "code",
40
+ "execution_count": 85,
41
+ "id": "1970ad63",
42
  "metadata": {},
43
  "outputs": [],
44
  "source": [
45
+ "from dataset import *\n",
46
+ "from cnn import CNNetwork"
47
  ]
48
  },
49
  {
50
  "cell_type": "code",
51
  "execution_count": 78,
52
+ "id": "c28b0c5e",
53
  "metadata": {},
54
  "outputs": [
55
  {
 
71
  {
72
  "cell_type": "code",
73
  "execution_count": 80,
74
+ "id": "69025839",
75
  "metadata": {},
76
  "outputs": [],
77
  "source": [
 
87
  {
88
  "cell_type": "code",
89
  "execution_count": 81,
90
+ "id": "8dfbb1b4",
91
  "metadata": {},
92
  "outputs": [
93
  {
 
108
  {
109
  "cell_type": "code",
110
  "execution_count": 82,
111
+ "id": "1071e53d",
112
  "metadata": {},
113
  "outputs": [
114
  {
 
136
  {
137
  "cell_type": "code",
138
  "execution_count": 83,
139
+ "id": "7a6d8133",
140
  "metadata": {},
141
  "outputs": [
142
  {
 
154
  "dataset[0][0].shape"
155
  ]
156
  },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": 87,
160
+ "id": "4b8f75a0",
161
+ "metadata": {},
162
+ "outputs": [
163
+ {
164
+ "name": "stdout",
165
+ "output_type": "stream",
166
+ "text": [
167
+ "----------------------------------------------------------------\n",
168
+ " Layer (type) Output Shape Param #\n",
169
+ "================================================================\n",
170
+ " Conv2d-1 [-1, 16, 66, 46] 160\n",
171
+ " ReLU-2 [-1, 16, 66, 46] 0\n",
172
+ " MaxPool2d-3 [-1, 16, 33, 23] 0\n",
173
+ " Conv2d-4 [-1, 32, 35, 25] 4,640\n",
174
+ " ReLU-5 [-1, 32, 35, 25] 0\n",
175
+ " MaxPool2d-6 [-1, 32, 17, 12] 0\n",
176
+ " Conv2d-7 [-1, 64, 19, 14] 18,496\n",
177
+ " ReLU-8 [-1, 64, 19, 14] 0\n",
178
+ " MaxPool2d-9 [-1, 64, 9, 7] 0\n",
179
+ " Conv2d-10 [-1, 128, 11, 9] 73,856\n",
180
+ " ReLU-11 [-1, 128, 11, 9] 0\n",
181
+ " MaxPool2d-12 [-1, 128, 5, 4] 0\n",
182
+ " Flatten-13 [-1, 2560] 0\n",
183
+ " Linear-14 [-1, 10] 25,610\n",
184
+ " Softmax-15 [-1, 10] 0\n",
185
+ "================================================================\n",
186
+ "Total params: 122,762\n",
187
+ "Trainable params: 122,762\n",
188
+ "Non-trainable params: 0\n",
189
+ "----------------------------------------------------------------\n",
190
+ "Input size (MB): 0.01\n",
191
+ "Forward/backward pass size (MB): 1.83\n",
192
+ "Params size (MB): 0.47\n",
193
+ "Estimated Total Size (MB): 2.31\n",
194
+ "----------------------------------------------------------------\n"
195
+ ]
196
+ }
197
+ ],
198
+ "source": [
199
+ "cnn = CNNetwork()\n",
200
+ "summary(cnn, (1, 64, 44))"
201
+ ]
202
+ },
203
  {
204
  "cell_type": "code",
205
  "execution_count": null,
206
+ "id": "888726ed",
207
  "metadata": {},
208
  "outputs": [],
209
  "source": []