No application file
No application file
trying to do Autoencoder but failing
Browse files- CNN-Autoencoder.ipynb +476 -0
@@ -0,0 +1,476 @@
1 |
2 |
"cells": [
3 |
4 |
"cell_type": "code",
5 |
"execution_count": null,
6 |
"id": "4f403af3",
7 |
"metadata": {},
8 |
"outputs": [],
9 |
"source": [
10 |
11 |
12 |
13 |
14 |
"cell_type": "code",
15 |
"execution_count": 46,
16 |
"id": "add961d3",
17 |
"metadata": {},
18 |
"outputs": [],
19 |
"source": [
"import os\n",
"import random\n",
"\n",
"import matplotlib.pyplot as plt  # plotting library\n",
"import numpy as np  # this module is useful to work with numerical arrays\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import torchvision\n",
"from sklearn.model_selection import train_test_split\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision import transforms, datasets"
33 |
34 |
35 |
36 |
"cell_type": "code",
37 |
"execution_count": 3,
38 |
"id": "7f5313b5",
39 |
"metadata": {},
40 |
"outputs": [],
41 |
"source": [
"def find_candidate_images(images_path, extensions=('.jpg', '.png', '.jpeg')):\n",
"    \"\"\"\n",
"    Finds all candidate images in the given folder and its sub-folders.\n",
"\n",
"    Args:\n",
"        images_path: root directory that is searched recursively.\n",
"        extensions: file extensions (with leading dot) that count as images;\n",
"            matching is case-insensitive.\n",
"\n",
"    Returns:\n",
"        images: a sorted list of absolute paths to the discovered images.\n",
"            Empty if the folder does not exist or holds no matching file.\n",
"    \"\"\"\n",
"    allowed = tuple(ext.lower() for ext in extensions)\n",
"    images = []\n",
"    for root, _dirs, files in os.walk(images_path):\n",
"        for name in files:\n",
"            # Only build the absolute path for files that actually match.\n",
"            if os.path.splitext(name)[1].lower() in allowed:\n",
"                images.append(os.path.abspath(os.path.join(root, name)))\n",
"    # os.walk order depends on the filesystem; sorting keeps a seeded\n",
"    # train/test split reproducible across machines.\n",
"    images.sort()\n",
"    return images"
56 |
57 |
58 |
59 |
"cell_type": "code",
60 |
"execution_count": 49,
61 |
"id": "1e7f0096",
62 |
"metadata": {},
63 |
"outputs": [],
64 |
"source": [
"class MyDataset(torch.utils.data.Dataset):\n",
"    \"\"\"Dataset that lazily loads images from a list of file paths.\n",
"\n",
"    Args:\n",
"        img_list: sequence of image file paths.\n",
"        augmentations: callable applied to every loaded image (for example a\n",
"            torchvision transforms.Compose); None returns the loaded image as is.\n",
"        loader: callable mapping a path to an image object. Defaults to\n",
"            torchvision's datasets.folder.default_loader.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, img_list, augmentations, loader=None):\n",
"        super().__init__()\n",
"        self.img_list = img_list\n",
"        self.augmentations = augmentations\n",
"        # Resolved here so that a custom loader never needs torchvision.\n",
"        self.loader = loader if loader is not None else datasets.folder.default_loader\n",
"\n",
"    def __len__(self):\n",
"        return len(self.img_list)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        # img_list holds paths, not pixels: transforms such as ToTensor reject a\n",
"        # str, so the file has to be opened before the augmentations run.\n",
"        img = self.loader(self.img_list[idx])\n",
"        if self.augmentations is None:\n",
"            return img\n",
"        return self.augmentations(img)"
77 |
78 |
79 |
80 |
"cell_type": "code",
81 |
"execution_count": 51,
82 |
"id": "f846b86c",
83 |
"metadata": {},
84 |
"outputs": [],
85 |
"source": [
86 |
"images = find_candidate_images('../SD_sample_f_m_pt2')"
87 |
88 |
89 |
90 |
"cell_type": "code",
91 |
"execution_count": 43,
92 |
"id": "da000292",
93 |
"metadata": {},
94 |
"outputs": [],
95 |
"source": [
96 |
"transform = transforms.Compose([\n",
97 |
98 |
99 |
100 |
101 |
102 |
"cell_type": "code",
103 |
"execution_count": 55,
104 |
"id": "d8f46911",
105 |
"metadata": {},
106 |
"outputs": [],
107 |
"source": [
108 |
"data = MyDataset(images, transform)\n",
109 |
"dataset_iterator = DataLoader(data, batch_size=1)"
110 |
111 |
112 |
113 |
"cell_type": "code",
114 |
"execution_count": 56,
115 |
"id": "05504c87",
116 |
"metadata": {},
117 |
"outputs": [
118 |
119 |
"ename": "TypeError",
120 |
"evalue": "pic should be PIL Image or ndarray. Got <class 'str'>",
121 |
"output_type": "error",
122 |
"traceback": [
123 |
124 |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
125 |
"Input \u001b[0;32mIn [56]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m train_images, test_images \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_test_split\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.33\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(train_images))\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(test_images))\n",
126 |
"File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/model_selection/\u001b[0m, in \u001b[0;36mtrain_test_split\u001b[0;34m(test_size, train_size, random_state, shuffle, stratify, *arrays)\u001b[0m\n\u001b[1;32m 2467\u001b[0m cv \u001b[38;5;241m=\u001b[39m CVClass(test_size\u001b[38;5;241m=\u001b[39mn_test, train_size\u001b[38;5;241m=\u001b[39mn_train, random_state\u001b[38;5;241m=\u001b[39mrandom_state)\n\u001b[1;32m 2469\u001b[0m train, test \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(cv\u001b[38;5;241m.\u001b[39msplit(X\u001b[38;5;241m=\u001b[39marrays[\u001b[38;5;241m0\u001b[39m], y\u001b[38;5;241m=\u001b[39mstratify))\n\u001b[0;32m-> 2471\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2472\u001b[0m \u001b[43m \u001b[49m\u001b[43mchain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_iterable\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2473\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43m_safe_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_safe_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43ma\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43marrays\u001b[49m\n\u001b[1;32m 2474\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2475\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
127 |
"File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/model_selection/\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 2467\u001b[0m cv \u001b[38;5;241m=\u001b[39m CVClass(test_size\u001b[38;5;241m=\u001b[39mn_test, train_size\u001b[38;5;241m=\u001b[39mn_train, random_state\u001b[38;5;241m=\u001b[39mrandom_state)\n\u001b[1;32m 2469\u001b[0m train, test \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(cv\u001b[38;5;241m.\u001b[39msplit(X\u001b[38;5;241m=\u001b[39marrays[\u001b[38;5;241m0\u001b[39m], y\u001b[38;5;241m=\u001b[39mstratify))\n\u001b[1;32m 2471\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(\n\u001b[1;32m 2472\u001b[0m chain\u001b[38;5;241m.\u001b[39mfrom_iterable(\n\u001b[0;32m-> 2473\u001b[0m (\u001b[43m_safe_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m)\u001b[49m, _safe_indexing(a, test)) \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m arrays\n\u001b[1;32m 2474\u001b[0m )\n\u001b[1;32m 2475\u001b[0m )\n",
128 |
"File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/utils/\u001b[0m, in \u001b[0;36m_safe_indexing\u001b[0;34m(X, indices, axis)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _array_indexing(X, indices, indices_dtype, axis\u001b[38;5;241m=\u001b[39maxis)\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_list_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices_dtype\u001b[49m\u001b[43m)\u001b[49m\n",
129 |
"File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/utils/\u001b[0m, in \u001b[0;36m_list_indexing\u001b[0;34m(X, key, key_dtype)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(compress(X, key))\n\u001b[1;32m 216\u001b[0m \u001b[38;5;66;03m# key is a integer array-like of key\u001b[39;00m\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [X[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m key]\n",
130 |
"File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/utils/\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(compress(X, key))\n\u001b[1;32m 216\u001b[0m \u001b[38;5;66;03m# key is a integer array-like of key\u001b[39;00m\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\u001b[43mX\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m key]\n",
131 |
"Input \u001b[0;32mIn [49]\u001b[0m, in \u001b[0;36mMyDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, idx):\n\u001b[1;32m 11\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimg_list[idx]\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maugmentations\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n",
132 |
"File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/torchvision/transforms/\u001b[0m, in \u001b[0;36mCompose.__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, img):\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransforms:\n\u001b[0;32m---> 95\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[43mt\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m img\n",
133 |
"File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/torchvision/transforms/\u001b[0m, in \u001b[0;36mToTensor.__call__\u001b[0;34m(self, pic)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, pic):\n\u001b[1;32m 128\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 129\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[1;32m 130\u001b[0m \u001b[38;5;124;03m pic (PIL Image or numpy.ndarray): Image to be converted to tensor.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[38;5;124;03m Tensor: Converted image.\u001b[39;00m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpic\u001b[49m\u001b[43m)\u001b[49m\n",
134 |
"File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/torchvision/transforms/\u001b[0m, in \u001b[0;36mto_tensor\u001b[0;34m(pic)\u001b[0m\n\u001b[1;32m 135\u001b[0m _log_api_usage_once(to_tensor)\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (F_pil\u001b[38;5;241m.\u001b[39m_is_pil_image(pic) \u001b[38;5;129;01mor\u001b[39;00m _is_numpy(pic)):\n\u001b[0;32m--> 137\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpic should be PIL Image or ndarray. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(pic)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_numpy(pic) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_numpy_image(pic):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpic should be 2/3 dimensional. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpic\u001b[38;5;241m.\u001b[39mndim\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m dimensions.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
135 |
"\u001b[0;31mTypeError\u001b[0m: pic should be PIL Image or ndarray. Got <class 'str'>"
136 |
137 |
138 |
139 |
"source": [
"# Split indices rather than the dataset itself: indexing the dataset loads and\n",
"# transforms an image, so splitting it directly materialises every sample.\n",
"train_indices, test_indices = train_test_split(\n",
"    list(range(len(data))), test_size=0.33, random_state=42\n",
")\n",
"train_images = torch.utils.data.Subset(data, train_indices)\n",
"test_images = torch.utils.data.Subset(data, test_indices)\n",
"print(len(train_images))\n",
"print(len(test_images))"
143 |
144 |
145 |
146 |
"cell_type": "code",
147 |
"execution_count": 16,
148 |
"id": "669f82ab",
149 |
"metadata": {},
150 |
"outputs": [],
151 |
"source": [
152 |
153 |
154 |
155 |
156 |
"cell_type": "code",
157 |
"execution_count": 23,
158 |
"id": "e962953c",
159 |
"metadata": {},
160 |
"outputs": [],
161 |
"source": [
"# Derive both lengths from the dataset so they always sum to len(train_images);\n",
"# truncating the two shares independently can lose a sample, which makes\n",
"# random_split raise.\n",
"n_val = int(len(train_images) * 0.2)\n",
"n_train = len(train_images) - n_val\n",
"train_data, val_data = random_split(train_images, [n_train, n_val])\n",
"test_dataset = test_images"
164 |
165 |
166 |
167 |
"cell_type": "code",
168 |
"execution_count": 24,
169 |
"id": "16a8e2a1",
170 |
"metadata": {},
171 |
"outputs": [],
172 |
"source": [
"# NOTE(review): batch_size is expected to come from an earlier cell - confirm.\n",
"# Only the training data is shuffled; evaluation order does not change the loss.\n",
"train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
"valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)\n",
"test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)"
176 |
177 |
178 |
179 |
"cell_type": "code",
180 |
"execution_count": 25,
181 |
"id": "07403239",
182 |
"metadata": {},
183 |
"outputs": [],
184 |
"source": [
"class Encoder(nn.Module):\n",
"    \"\"\"Convolutional encoder mapping an image batch to a latent vector.\n",
"\n",
"    Args:\n",
"        encoded_space_dim: size of the latent vector produced per image.\n",
"        fc2_input_dim: width of the hidden fully-connected layer.\n",
"\n",
"    The linear section is sized for a 32x3x3 feature map, which is what the\n",
"    three stride-2 convolutions produce for single-channel 28x28 inputs.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, encoded_space_dim, fc2_input_dim):\n",
"        super().__init__()\n",
"\n",
"        ### Convolutional section\n",
"        self.encoder_cnn = nn.Sequential(\n",
"            nn.Conv2d(1, 8, 3, stride=2, padding=1),\n",
"            nn.ReLU(True),\n",
"            nn.Conv2d(8, 16, 3, stride=2, padding=1),\n",
"            nn.BatchNorm2d(16),\n",
"            nn.ReLU(True),\n",
"            nn.Conv2d(16, 32, 3, stride=2, padding=0),\n",
"            nn.ReLU(True)\n",
"        )\n",
"\n",
"        ### Flatten layer\n",
"        self.flatten = nn.Flatten(start_dim=1)\n",
"\n",
"        ### Linear section\n",
"        # fc2_input_dim used to be accepted but ignored (width fixed at 128).\n",
"        self.encoder_lin = nn.Sequential(\n",
"            nn.Linear(3 * 3 * 32, fc2_input_dim),\n",
"            nn.ReLU(True),\n",
"            nn.Linear(fc2_input_dim, encoded_space_dim)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Encode a (batch, 1, 28, 28) tensor into (batch, encoded_space_dim).\"\"\"\n",
"        x = self.encoder_cnn(x)\n",
"        x = self.flatten(x)\n",
"        x = self.encoder_lin(x)\n",
"        return x\n",
"\n",
"\n",
"class Decoder(nn.Module):\n",
"    \"\"\"Convolutional decoder mapping a latent vector back to an image.\n",
"\n",
"    Args:\n",
"        encoded_space_dim: size of the latent vector consumed per image.\n",
"        fc2_input_dim: width of the hidden fully-connected layer.\n",
"\n",
"    The transposed convolutions upsample the 32x3x3 feature map to a\n",
"    single-channel 28x28 image (3 -> 7 -> 14 -> 28).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, encoded_space_dim, fc2_input_dim):\n",
"        super().__init__()\n",
"\n",
"        ### Linear section\n",
"        # fc2_input_dim used to be accepted but ignored (width fixed at 128).\n",
"        self.decoder_lin = nn.Sequential(\n",
"            nn.Linear(encoded_space_dim, fc2_input_dim),\n",
"            nn.ReLU(True),\n",
"            nn.Linear(fc2_input_dim, 3 * 3 * 32),\n",
"            nn.ReLU(True)\n",
"        )\n",
"\n",
"        ### Unflatten layer\n",
"        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))\n",
"\n",
"        ### Transposed-convolution section\n",
"        self.decoder_conv = nn.Sequential(\n",
"            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),\n",
"            nn.BatchNorm2d(16),\n",
"            nn.ReLU(True),\n",
"            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),\n",
"            nn.BatchNorm2d(8),\n",
"            nn.ReLU(True),\n",
"            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Decode (batch, encoded_space_dim) into (batch, 1, 28, 28) values in [0, 1].\"\"\"\n",
"        x = self.decoder_lin(x)\n",
"        x = self.unflatten(x)\n",
"        x = self.decoder_conv(x)\n",
"        # Sigmoid keeps the reconstruction in the same range as ToTensor images.\n",
"        x = torch.sigmoid(x)\n",
"        return x"
248 |
249 |
250 |
251 |
"cell_type": "code",
252 |
"execution_count": 26,
253 |
"id": "fedfd708",
254 |
"metadata": {},
255 |
"outputs": [
256 |
257 |
"name": "stdout",
258 |
"output_type": "stream",
259 |
"text": [
260 |
"Selected device: cuda\n"
261 |
262 |
263 |
264 |
"data": {
265 |
"text/plain": [
266 |
267 |
" (decoder_lin): Sequential(\n",
268 |
" (0): Linear(in_features=4, out_features=128, bias=True)\n",
269 |
" (1): ReLU(inplace=True)\n",
270 |
" (2): Linear(in_features=128, out_features=288, bias=True)\n",
271 |
" (3): ReLU(inplace=True)\n",
272 |
" )\n",
273 |
" (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))\n",
274 |
" (decoder_conv): Sequential(\n",
275 |
" (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))\n",
276 |
" (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
277 |
" (2): ReLU(inplace=True)\n",
278 |
" (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n",
279 |
" (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
280 |
" (5): ReLU(inplace=True)\n",
281 |
" (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n",
282 |
" )\n",
283 |
284 |
285 |
286 |
"execution_count": 26,
287 |
"metadata": {},
288 |
"output_type": "execute_result"
289 |
290 |
291 |
"source": [
"### Define the loss function\n",
"loss_fn = torch.nn.MSELoss()\n",
"\n",
"### Define an optimizer (both for the encoder and the decoder!)\n",
"lr = 0.001\n",
"\n",
"### Set the random seed for reproducible results\n",
"# NOTE(review): the seed value is not visible in the saved cell - 0 is assumed.\n",
"torch.manual_seed(0)\n",
"\n",
"### Initialize the two networks\n",
"d = 4\n",
"\n",
"encoder = Encoder(encoded_space_dim=d,fc2_input_dim=128)\n",
"decoder = Decoder(encoded_space_dim=d,fc2_input_dim=128)\n",
"params_to_optimize = [\n",
"    {'params': encoder.parameters()},\n",
"    {'params': decoder.parameters()}\n",
"]\n",
"\n",
"# NOTE: this rebinds `optim`, shadowing the `torch.optim` alias imported at the\n",
"# top of the notebook; the name is kept because the training cell passes it on.\n",
"optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-05)\n",
"\n",
"# Check if the GPU is available\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"print(f'Selected device: {device}')\n",
"\n",
"# Move both the encoder and the decoder to the selected device\n",
"encoder.to(device)\n",
"decoder.to(device)"
321 |
322 |
323 |
324 |
"cell_type": "code",
325 |
"execution_count": 33,
326 |
"id": "bae32de2",
327 |
"metadata": {},
328 |
"outputs": [],
329 |
"source": [
"### Training function\n",
"def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):\n",
"    \"\"\"Run one optimisation pass over the dataloader and return the mean batch loss.\n",
"\n",
"    Args:\n",
"        encoder, decoder: the two halves of the autoencoder.\n",
"        device: device the batches are moved to before the forward pass.\n",
"        dataloader: iterable yielding either image tensors or\n",
"            (images, labels) pairs; labels are ignored.\n",
"        loss_fn: callable taking (reconstruction, original).\n",
"        optimizer: optimizer holding the encoder and decoder parameters.\n",
"\n",
"    Returns:\n",
"        Mean of the per-batch losses, or nan if the dataloader is empty.\n",
"    \"\"\"\n",
"    # Set train mode for both the encoder and the decoder\n",
"    encoder.train()\n",
"    decoder.train()\n",
"    train_loss = []\n",
"    for batch in dataloader:\n",
"        # An unlabelled dataset yields a bare tensor, a labelled one yields\n",
"        # (images, labels); unpacking two values only works for the latter.\n",
"        image_batch = batch[0] if isinstance(batch, (list, tuple)) else batch\n",
"        # Move tensor to the proper device\n",
"        image_batch = image_batch.to(device)\n",
"        # Encode data\n",
"        encoded_data = encoder(image_batch)\n",
"        # Decode data\n",
"        decoded_data = decoder(encoded_data)\n",
"        # Evaluate loss\n",
"        loss = loss_fn(decoded_data, image_batch)\n",
"        # Backward pass\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        # Print batch loss\n",
"        batch_loss = loss.item()\n",
"        print('\\t partial train loss (single batch): %f' % batch_loss)\n",
"        train_loss.append(batch_loss)\n",
"\n",
"    if not train_loss:\n",
"        return float('nan')\n",
"    return np.mean(train_loss)"
355 |
356 |
357 |
358 |
"cell_type": "code",
359 |
"execution_count": 28,
360 |
"id": "ff2ec5fd",
361 |
"metadata": {},
362 |
"outputs": [],
363 |
"source": [
"### Testing function\n",
"def test_epoch(encoder, decoder, device, dataloader, loss_fn):\n",
"    \"\"\"Evaluate the autoencoder on the dataloader and return the global loss.\n",
"\n",
"    Args:\n",
"        encoder, decoder: the two halves of the autoencoder.\n",
"        device: device the batches are moved to before the forward pass.\n",
"        dataloader: iterable yielding either image tensors or\n",
"            (images, labels) pairs; labels are ignored.\n",
"        loss_fn: callable taking (reconstruction, original).\n",
"\n",
"    Returns:\n",
"        The loss computed once over all reconstructions and originals\n",
"        concatenated on the CPU.\n",
"    \"\"\"\n",
"    # Set evaluation mode for encoder and decoder\n",
"    encoder.eval()\n",
"    decoder.eval()\n",
"    # Outputs and inputs of every batch, collected on the CPU\n",
"    conc_out = []\n",
"    conc_label = []\n",
"    with torch.no_grad():  # No need to track the gradients\n",
"        for batch in dataloader:\n",
"            # Accept both bare image batches and (images, labels) pairs.\n",
"            image_batch = batch[0] if isinstance(batch, (list, tuple)) else batch\n",
"            # Move tensor to the proper device\n",
"            image_batch = image_batch.to(device)\n",
"            # Encode, then decode\n",
"            decoded_data = decoder(encoder(image_batch))\n",
"            # Append the network output and the original image to the lists\n",
"            conc_out.append(decoded_data.cpu())\n",
"            conc_label.append(image_batch.cpu())\n",
"        # Create a single tensor with all the values in the lists\n",
"        conc_out = torch.cat(conc_out)\n",
"        conc_label = torch.cat(conc_label)\n",
"        # Evaluate global loss\n",
"        val_loss = loss_fn(conc_out, conc_label)\n",
"    return val_loss"
389 |
390 |
391 |
392 |
"cell_type": "code",
393 |
"execution_count": 29,
394 |
"id": "592ab5f1",
395 |
"metadata": {},
396 |
"outputs": [],
397 |
"source": [
"def plot_ae_outputs(encoder,decoder,n=10):\n",
"    \"\"\"Plot up to n test images (top row) above their reconstructions (bottom row).\n",
"\n",
"    Reads the notebook globals test_dataset and device. If test_dataset has a\n",
"    targets attribute, the first example of each class 0..n-1 is shown;\n",
"    otherwise the first n samples are used.\n",
"    \"\"\"\n",
"    def _to_displayable(tensor):\n",
"        # imshow wants (H, W) or (H, W, C); image tensors are channel-first.\n",
"        array = tensor.cpu().squeeze().numpy()\n",
"        if array.ndim == 3:\n",
"            array = array.transpose(1, 2, 0)\n",
"        return array\n",
"\n",
"    targets = getattr(test_dataset, 'targets', None)\n",
"    if targets is not None:\n",
"        targets = np.asarray(targets)\n",
"        indices = [np.where(targets == i)[0][0] for i in range(n)]\n",
"    else:\n",
"        # Unlabelled data (e.g. a folder of images) has no classes to pick from.\n",
"        indices = list(range(min(n, len(test_dataset))))\n",
"    if not indices:\n",
"        return\n",
"    n = len(indices)\n",
"\n",
"    encoder.eval()\n",
"    decoder.eval()\n",
"    plt.figure(figsize=(16,4.5))\n",
"    for i, idx in enumerate(indices):\n",
"        sample = test_dataset[idx]\n",
"        # Samples are either a bare image tensor or an (image, label) pair.\n",
"        img = sample[0] if isinstance(sample, (list, tuple)) else sample\n",
"        img = img.unsqueeze(0).to(device)\n",
"        with torch.no_grad():\n",
"            rec_img = decoder(encoder(img))\n",
"\n",
"        ax = plt.subplot(2,n,i+1)\n",
"        plt.imshow(_to_displayable(img), cmap='gist_gray')\n",
"        ax.get_xaxis().set_visible(False)\n",
"        ax.get_yaxis().set_visible(False)\n",
"        if i == n//2:\n",
"            ax.set_title('Original images')\n",
"\n",
"        ax = plt.subplot(2, n, i + 1 + n)\n",
"        plt.imshow(_to_displayable(rec_img), cmap='gist_gray')\n",
"        ax.get_xaxis().set_visible(False)\n",
"        ax.get_yaxis().set_visible(False)\n",
"        if i == n//2:\n",
"            ax.set_title('Reconstructed images')\n",
"    plt.show()"
421 |
422 |
423 |
424 |
"cell_type": "code",
425 |
"execution_count": 34,
426 |
"id": "5f8b646b",
427 |
"metadata": {},
428 |
"outputs": [
429 |
430 |
"ename": "ValueError",
431 |
"evalue": "too many values to unpack (expected 2)",
432 |
"output_type": "error",
433 |
"traceback": [
434 |
435 |
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
436 |
"Input \u001b[0;32mIn [34]\u001b[0m, in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m diz_loss \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_loss\u001b[39m\u001b[38;5;124m'\u001b[39m:[],\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m:[]}\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n\u001b[0;32m----> 4\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mencoder\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdecoder\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43mloss_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43moptim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m val_loss \u001b[38;5;241m=\u001b[39m test_epoch(encoder,decoder,device,test_loader,loss_fn)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m EPOCH \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m train loss \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m val loss \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(epoch \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m, num_epochs,train_loss,val_loss))\n",
437 |
"Input \u001b[0;32mIn [33]\u001b[0m, in \u001b[0;36mtrain_epoch\u001b[0;34m(encoder, decoder, device, dataloader, loss_fn, optimizer)\u001b[0m\n\u001b[1;32m 6\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Iterate the dataloader (we do not need the label values, this is unsupervised learning)\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m image_batch, _ \u001b[38;5;129;01min\u001b[39;00m dataloader: \u001b[38;5;66;03m# with \"_\" we just ignore the labels (the second element of the dataloader tuple)\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m# Move tensor to the proper device\u001b[39;00m\n\u001b[1;32m 10\u001b[0m image_batch \u001b[38;5;241m=\u001b[39m image_batch\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# Encode data\u001b[39;00m\n",
438 |
"\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)"
439 |
440 |
441 |
442 |
"source": [
"num_epochs = 30\n",
"diz_loss = {'train_loss':[],'val_loss':[]}\n",
"for epoch in range(num_epochs):\n",
"    train_loss = train_epoch(encoder,decoder,device,train_loader,loss_fn,optim)\n",
"    # Validate on the validation split; the test set stays untouched during training.\n",
"    val_loss = test_epoch(encoder,decoder,device,valid_loader,loss_fn)\n",
"    print('\\n EPOCH {}/{} \\t train loss {} \\t val loss {}'.format(epoch + 1, num_epochs,train_loss,val_loss))\n",
"    diz_loss['train_loss'].append(train_loss)\n",
"    diz_loss['val_loss'].append(val_loss)\n",
"    plot_ae_outputs(encoder,decoder,n=10)"
452 |
453 |
454 |
455 |
"metadata": {
456 |
"kernelspec": {
457 |
"display_name": "Python 3 (ipykernel)",
458 |
"language": "python",
459 |
"name": "python3"
460 |
461 |
"language_info": {
462 |
"codemirror_mode": {
463 |
"name": "ipython",
464 |
"version": 3
465 |
466 |
"file_extension": ".py",
467 |
"mimetype": "text/x-python",
468 |
"name": "python",
469 |
"nbconvert_exporter": "python",
470 |
"pygments_lexer": "ipython3",
471 |
"version": "3.9.12"
472 |
473 |
474 |
"nbformat": 4,
475 |
"nbformat_minor": 5
476 |