Spaces:
Runtime error
Runtime error
Create CNNetwork
Browse files- __pycache__/cnn.cpython-39.pyc +0 -0
- cnn.py +66 -0
- notebooks/playground.ipynb +62 -14
__pycache__/cnn.cpython-39.pyc
ADDED
Binary file (1.37 kB). View file
|
|
cnn.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from torchsummary import summary
|
3 |
+
|
4 |
+
|
5 |
+
class CNNetwork(nn.Module):
|
6 |
+
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__()
|
9 |
+
# 4 conv blocks / flatten / linear / softmax
|
10 |
+
self.conv1 = nn.Sequential(
|
11 |
+
nn.Conv2d(
|
12 |
+
in_channels=1,
|
13 |
+
out_channels=16,
|
14 |
+
kernel_size=3,
|
15 |
+
stride=1,
|
16 |
+
padding=2
|
17 |
+
),
|
18 |
+
nn.ReLU(),
|
19 |
+
nn.MaxPool2d(kernel_size=2)
|
20 |
+
)
|
21 |
+
self.conv2 = nn.Sequential(
|
22 |
+
nn.Conv2d(
|
23 |
+
in_channels=16,
|
24 |
+
out_channels=32,
|
25 |
+
kernel_size=3,
|
26 |
+
stride=1,
|
27 |
+
padding=2
|
28 |
+
),
|
29 |
+
nn.ReLU(),
|
30 |
+
nn.MaxPool2d(kernel_size=2)
|
31 |
+
)
|
32 |
+
self.conv3 = nn.Sequential(
|
33 |
+
nn.Conv2d(
|
34 |
+
in_channels=32,
|
35 |
+
out_channels=64,
|
36 |
+
kernel_size=3,
|
37 |
+
stride=1,
|
38 |
+
padding=2
|
39 |
+
),
|
40 |
+
nn.ReLU(),
|
41 |
+
nn.MaxPool2d(kernel_size=2)
|
42 |
+
)
|
43 |
+
self.conv4 = nn.Sequential(
|
44 |
+
nn.Conv2d(
|
45 |
+
in_channels=64,
|
46 |
+
out_channels=128,
|
47 |
+
kernel_size=3,
|
48 |
+
stride=1,
|
49 |
+
padding=2
|
50 |
+
),
|
51 |
+
nn.ReLU(),
|
52 |
+
nn.MaxPool2d(kernel_size=2)
|
53 |
+
)
|
54 |
+
self.flatten = nn.Flatten()
|
55 |
+
self.linear = nn.Linear(128 * 5 * 4, 10)
|
56 |
+
self.softmax = nn.Softmax(dim=1)
|
57 |
+
|
58 |
+
def forward(self, input_data):
|
59 |
+
x = self.conv1(input_data)
|
60 |
+
x = self.conv2(x)
|
61 |
+
x = self.conv3(x)
|
62 |
+
x = self.conv4(x)
|
63 |
+
x = self.flatten(x)
|
64 |
+
logits = self.linear(x)
|
65 |
+
predictions = self.softmax(logits)
|
66 |
+
return predictions
|
notebooks/playground.ipynb
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
"execution_count": 8,
|
6 |
-
"id": "
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
@@ -14,7 +14,7 @@
|
|
14 |
{
|
15 |
"cell_type": "code",
|
16 |
"execution_count": 10,
|
17 |
-
"id": "
|
18 |
"metadata": {},
|
19 |
"outputs": [],
|
20 |
"source": [
|
@@ -24,30 +24,32 @@
|
|
24 |
},
|
25 |
{
|
26 |
"cell_type": "code",
|
27 |
-
"execution_count":
|
28 |
-
"id": "
|
29 |
"metadata": {},
|
30 |
"outputs": [],
|
31 |
"source": [
|
32 |
"import os\n",
|
33 |
"\n",
|
34 |
-
"import torch"
|
|
|
35 |
]
|
36 |
},
|
37 |
{
|
38 |
"cell_type": "code",
|
39 |
-
"execution_count":
|
40 |
-
"id": "
|
41 |
"metadata": {},
|
42 |
"outputs": [],
|
43 |
"source": [
|
44 |
-
"from dataset import
|
|
|
45 |
]
|
46 |
},
|
47 |
{
|
48 |
"cell_type": "code",
|
49 |
"execution_count": 78,
|
50 |
-
"id": "
|
51 |
"metadata": {},
|
52 |
"outputs": [
|
53 |
{
|
@@ -69,7 +71,7 @@
|
|
69 |
{
|
70 |
"cell_type": "code",
|
71 |
"execution_count": 80,
|
72 |
-
"id": "
|
73 |
"metadata": {},
|
74 |
"outputs": [],
|
75 |
"source": [
|
@@ -85,7 +87,7 @@
|
|
85 |
{
|
86 |
"cell_type": "code",
|
87 |
"execution_count": 81,
|
88 |
-
"id": "
|
89 |
"metadata": {},
|
90 |
"outputs": [
|
91 |
{
|
@@ -106,7 +108,7 @@
|
|
106 |
{
|
107 |
"cell_type": "code",
|
108 |
"execution_count": 82,
|
109 |
-
"id": "
|
110 |
"metadata": {},
|
111 |
"outputs": [
|
112 |
{
|
@@ -134,7 +136,7 @@
|
|
134 |
{
|
135 |
"cell_type": "code",
|
136 |
"execution_count": 83,
|
137 |
-
"id": "
|
138 |
"metadata": {},
|
139 |
"outputs": [
|
140 |
{
|
@@ -152,10 +154,56 @@
|
|
152 |
"dataset[0][0].shape"
|
153 |
]
|
154 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
{
|
156 |
"cell_type": "code",
|
157 |
"execution_count": null,
|
158 |
-
"id": "
|
159 |
"metadata": {},
|
160 |
"outputs": [],
|
161 |
"source": []
|
|
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
"execution_count": 8,
|
6 |
+
"id": "46dbbffd",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
|
|
14 |
{
|
15 |
"cell_type": "code",
|
16 |
"execution_count": 10,
|
17 |
+
"id": "56056453",
|
18 |
"metadata": {},
|
19 |
"outputs": [],
|
20 |
"source": [
|
|
|
24 |
},
|
25 |
{
|
26 |
"cell_type": "code",
|
27 |
+
"execution_count": 86,
|
28 |
+
"id": "b5cbac9d",
|
29 |
"metadata": {},
|
30 |
"outputs": [],
|
31 |
"source": [
|
32 |
"import os\n",
|
33 |
"\n",
|
34 |
+
"import torch\n",
|
35 |
+
"from torchsummary import summary"
|
36 |
]
|
37 |
},
|
38 |
{
|
39 |
"cell_type": "code",
|
40 |
+
"execution_count": 85,
|
41 |
+
"id": "1970ad63",
|
42 |
"metadata": {},
|
43 |
"outputs": [],
|
44 |
"source": [
|
45 |
+
"from dataset import *\n",
|
46 |
+
"from cnn import CNNetwork"
|
47 |
]
|
48 |
},
|
49 |
{
|
50 |
"cell_type": "code",
|
51 |
"execution_count": 78,
|
52 |
+
"id": "c28b0c5e",
|
53 |
"metadata": {},
|
54 |
"outputs": [
|
55 |
{
|
|
|
71 |
{
|
72 |
"cell_type": "code",
|
73 |
"execution_count": 80,
|
74 |
+
"id": "69025839",
|
75 |
"metadata": {},
|
76 |
"outputs": [],
|
77 |
"source": [
|
|
|
87 |
{
|
88 |
"cell_type": "code",
|
89 |
"execution_count": 81,
|
90 |
+
"id": "8dfbb1b4",
|
91 |
"metadata": {},
|
92 |
"outputs": [
|
93 |
{
|
|
|
108 |
{
|
109 |
"cell_type": "code",
|
110 |
"execution_count": 82,
|
111 |
+
"id": "1071e53d",
|
112 |
"metadata": {},
|
113 |
"outputs": [
|
114 |
{
|
|
|
136 |
{
|
137 |
"cell_type": "code",
|
138 |
"execution_count": 83,
|
139 |
+
"id": "7a6d8133",
|
140 |
"metadata": {},
|
141 |
"outputs": [
|
142 |
{
|
|
|
154 |
"dataset[0][0].shape"
|
155 |
]
|
156 |
},
|
157 |
+
{
|
158 |
+
"cell_type": "code",
|
159 |
+
"execution_count": 87,
|
160 |
+
"id": "4b8f75a0",
|
161 |
+
"metadata": {},
|
162 |
+
"outputs": [
|
163 |
+
{
|
164 |
+
"name": "stdout",
|
165 |
+
"output_type": "stream",
|
166 |
+
"text": [
|
167 |
+
"----------------------------------------------------------------\n",
|
168 |
+
" Layer (type) Output Shape Param #\n",
|
169 |
+
"================================================================\n",
|
170 |
+
" Conv2d-1 [-1, 16, 66, 46] 160\n",
|
171 |
+
" ReLU-2 [-1, 16, 66, 46] 0\n",
|
172 |
+
" MaxPool2d-3 [-1, 16, 33, 23] 0\n",
|
173 |
+
" Conv2d-4 [-1, 32, 35, 25] 4,640\n",
|
174 |
+
" ReLU-5 [-1, 32, 35, 25] 0\n",
|
175 |
+
" MaxPool2d-6 [-1, 32, 17, 12] 0\n",
|
176 |
+
" Conv2d-7 [-1, 64, 19, 14] 18,496\n",
|
177 |
+
" ReLU-8 [-1, 64, 19, 14] 0\n",
|
178 |
+
" MaxPool2d-9 [-1, 64, 9, 7] 0\n",
|
179 |
+
" Conv2d-10 [-1, 128, 11, 9] 73,856\n",
|
180 |
+
" ReLU-11 [-1, 128, 11, 9] 0\n",
|
181 |
+
" MaxPool2d-12 [-1, 128, 5, 4] 0\n",
|
182 |
+
" Flatten-13 [-1, 2560] 0\n",
|
183 |
+
" Linear-14 [-1, 10] 25,610\n",
|
184 |
+
" Softmax-15 [-1, 10] 0\n",
|
185 |
+
"================================================================\n",
|
186 |
+
"Total params: 122,762\n",
|
187 |
+
"Trainable params: 122,762\n",
|
188 |
+
"Non-trainable params: 0\n",
|
189 |
+
"----------------------------------------------------------------\n",
|
190 |
+
"Input size (MB): 0.01\n",
|
191 |
+
"Forward/backward pass size (MB): 1.83\n",
|
192 |
+
"Params size (MB): 0.47\n",
|
193 |
+
"Estimated Total Size (MB): 2.31\n",
|
194 |
+
"----------------------------------------------------------------\n"
|
195 |
+
]
|
196 |
+
}
|
197 |
+
],
|
198 |
+
"source": [
|
199 |
+
"cnn = CNNetwork()\n",
|
200 |
+
"summary(cnn, (1, 64, 44))"
|
201 |
+
]
|
202 |
+
},
|
203 |
{
|
204 |
"cell_type": "code",
|
205 |
"execution_count": null,
|
206 |
+
"id": "888726ed",
|
207 |
"metadata": {},
|
208 |
"outputs": [],
|
209 |
"source": []
|