sasha (HF staff) committed on
Commit 5d26c71
1 Parent(s): 8e00a6e

trying to do Autoencoder but failing

Files changed (1)
  1. CNN-Autoencoder.ipynb +476 -0
CNN-Autoencoder.ipynb ADDED
@@ -0,0 +1,476 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "4f403af3",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "#Source: https://medium.com/dataseries/convolutional-autoencoder-in-pytorch-on-mnist-dataset-d65145c132ac"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 46,
16
+ "id": "add961d3",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "import matplotlib.pyplot as plt # plotting library\n",
21
+ "from sklearn.model_selection import train_test_split\n",
22
+ "import numpy as np # this module is useful to work with numerical arrays\n",
23
+ "import pandas as pd \n",
24
+ "import random \n",
25
+ "import os\n",
26
+ "import torch\n",
27
+ "import torchvision\n",
28
+ "from torchvision import transforms, datasets\n",
29
+ "from torch.utils.data import DataLoader,random_split\n",
30
+ "from torch import nn\n",
31
+ "import torch.nn.functional as F\n",
32
+ "import torch.optim as optim"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 3,
38
+ "id": "7f5313b5",
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "def find_candidate_images(images_path):\n",
43
+ " \"\"\"\n",
44
+ " Finds all candidate images in the given folder and its sub-folders.\n",
45
+ "\n",
46
+ " Returns:\n",
47
+ " images: a list of absolute paths to the discovered images.\n",
48
+ " \"\"\"\n",
49
+ " images = []\n",
50
+ " for root, dirs, files in os.walk(images_path):\n",
51
+ " for name in files:\n",
52
+ " file_path = os.path.abspath(os.path.join(root, name))\n",
53
+ " if ((os.path.splitext(name)[1]).lower() in ['.jpg','.png','.jpeg']):\n",
54
+ " images.append(file_path)\n",
55
+ " return images"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 49,
61
+ "id": "1e7f0096",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "class MyDataset(torch.utils.data.Dataset):\n",
66
+ " def __init__(self, img_list, augmentations):\n",
67
+ " super(MyDataset, self).__init__()\n",
68
+ " self.img_list = img_list\n",
69
+ " self.augmentations = augmentations\n",
70
+ "\n",
71
+ " def __len__(self):\n",
72
+ " return len(self.img_list)\n",
73
+ "\n",
74
+ " def __getitem__(self, idx):\n",
75
+ " img = self.img_list[idx]\n",
76
+ " return self.augmentations(img)"
77
+ ]
78
+ },
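Note: __getitem__ as written hands the transform pipeline a path string, which is what later raises "pic should be PIL Image or ndarray. Got <class 'str'>". A minimal sketch of a fix, assuming the paths point to RGB image files and Pillow is available:

    from PIL import Image

    def __getitem__(self, idx):
        # Open the file behind the stored path before applying the transforms;
        # ToTensor expects a PIL Image or an ndarray, not a str.
        img = Image.open(self.img_list[idx]).convert("RGB")
        return self.augmentations(img)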
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 51,
82
+ "id": "f846b86c",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "images = find_candidate_images('../SD_sample_f_m_pt2')"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 43,
92
+ "id": "da000292",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "transform = transforms.Compose([\n",
97
+ "transforms.ToTensor(),\n",
98
+ "])"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 55,
104
+ "id": "d8f46911",
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "data = MyDataset(images, transform)\n",
109
+ "dataset_iterator = DataLoader(data, batch_size=1)"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 56,
115
+ "id": "05504c87",
116
+ "metadata": {},
117
+ "outputs": [
118
+ {
119
+ "ename": "TypeError",
120
+ "evalue": "pic should be PIL Image or ndarray. Got <class 'str'>",
121
+ "output_type": "error",
122
+ "traceback": [
123
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
124
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
125
+ "Input \u001b[0;32mIn [56]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m train_images, test_images \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_test_split\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.33\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(train_images))\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(test_images))\n",
126
+ "File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/model_selection/_split.py:2471\u001b[0m, in \u001b[0;36mtrain_test_split\u001b[0;34m(test_size, train_size, random_state, shuffle, stratify, *arrays)\u001b[0m\n\u001b[1;32m 2467\u001b[0m cv \u001b[38;5;241m=\u001b[39m CVClass(test_size\u001b[38;5;241m=\u001b[39mn_test, train_size\u001b[38;5;241m=\u001b[39mn_train, random_state\u001b[38;5;241m=\u001b[39mrandom_state)\n\u001b[1;32m 2469\u001b[0m train, test \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(cv\u001b[38;5;241m.\u001b[39msplit(X\u001b[38;5;241m=\u001b[39marrays[\u001b[38;5;241m0\u001b[39m], y\u001b[38;5;241m=\u001b[39mstratify))\n\u001b[0;32m-> 2471\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2472\u001b[0m \u001b[43m \u001b[49m\u001b[43mchain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_iterable\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2473\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43m_safe_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_safe_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43ma\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43marrays\u001b[49m\n\u001b[1;32m 2474\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2475\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
127
+ "File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/model_selection/_split.py:2473\u001b[0m, in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 2467\u001b[0m cv \u001b[38;5;241m=\u001b[39m CVClass(test_size\u001b[38;5;241m=\u001b[39mn_test, train_size\u001b[38;5;241m=\u001b[39mn_train, random_state\u001b[38;5;241m=\u001b[39mrandom_state)\n\u001b[1;32m 2469\u001b[0m train, test \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(cv\u001b[38;5;241m.\u001b[39msplit(X\u001b[38;5;241m=\u001b[39marrays[\u001b[38;5;241m0\u001b[39m], y\u001b[38;5;241m=\u001b[39mstratify))\n\u001b[1;32m 2471\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(\n\u001b[1;32m 2472\u001b[0m chain\u001b[38;5;241m.\u001b[39mfrom_iterable(\n\u001b[0;32m-> 2473\u001b[0m (\u001b[43m_safe_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m)\u001b[49m, _safe_indexing(a, test)) \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m arrays\n\u001b[1;32m 2474\u001b[0m )\n\u001b[1;32m 2475\u001b[0m )\n",
128
+ "File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/utils/__init__.py:363\u001b[0m, in \u001b[0;36m_safe_indexing\u001b[0;34m(X, indices, axis)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _array_indexing(X, indices, indices_dtype, axis\u001b[38;5;241m=\u001b[39maxis)\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_list_indexing\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices_dtype\u001b[49m\u001b[43m)\u001b[49m\n",
129
+ "File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/utils/__init__.py:217\u001b[0m, in \u001b[0;36m_list_indexing\u001b[0;34m(X, key, key_dtype)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(compress(X, key))\n\u001b[1;32m 216\u001b[0m \u001b[38;5;66;03m# key is a integer array-like of key\u001b[39;00m\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [X[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m key]\n",
130
+ "File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/sklearn/utils/__init__.py:217\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(compress(X, key))\n\u001b[1;32m 216\u001b[0m \u001b[38;5;66;03m# key is a integer array-like of key\u001b[39;00m\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\u001b[43mX\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m key]\n",
131
+ "Input \u001b[0;32mIn [49]\u001b[0m, in \u001b[0;36mMyDataset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, idx):\n\u001b[1;32m 11\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mimg_list[idx]\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maugmentations\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n",
132
+ "File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/torchvision/transforms/transforms.py:95\u001b[0m, in \u001b[0;36mCompose.__call__\u001b[0;34m(self, img)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, img):\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransforms:\n\u001b[0;32m---> 95\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[43mt\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m img\n",
133
+ "File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/torchvision/transforms/transforms.py:135\u001b[0m, in \u001b[0;36mToTensor.__call__\u001b[0;34m(self, pic)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, pic):\n\u001b[1;32m 128\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 129\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[1;32m 130\u001b[0m \u001b[38;5;124;03m pic (PIL Image or numpy.ndarray): Image to be converted to tensor.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[38;5;124;03m Tensor: Converted image.\u001b[39;00m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpic\u001b[49m\u001b[43m)\u001b[49m\n",
134
+ "File \u001b[0;32m~/miniconda3/envs/stablediffusion/lib/python3.9/site-packages/torchvision/transforms/functional.py:137\u001b[0m, in \u001b[0;36mto_tensor\u001b[0;34m(pic)\u001b[0m\n\u001b[1;32m 135\u001b[0m _log_api_usage_once(to_tensor)\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (F_pil\u001b[38;5;241m.\u001b[39m_is_pil_image(pic) \u001b[38;5;129;01mor\u001b[39;00m _is_numpy(pic)):\n\u001b[0;32m--> 137\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpic should be PIL Image or ndarray. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(pic)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_numpy(pic) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_numpy_image(pic):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpic should be 2/3 dimensional. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpic\u001b[38;5;241m.\u001b[39mndim\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m dimensions.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
135
+ "\u001b[0;31mTypeError\u001b[0m: pic should be PIL Image or ndarray. Got <class 'str'>"
136
+ ]
137
+ }
138
+ ],
139
+ "source": [
140
+ "train_images, test_images = train_test_split(data, test_size=0.33, random_state=42)\n",
141
+ "print(len(train_images))\n",
142
+ "print(len(test_images))"
143
+ ]
144
+ },
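The TypeError above comes from how train_test_split indexes the Dataset: sklearn builds the train and test lists element by element, so MyDataset.__getitem__ runs on every index and fails on the path strings. One way around it, sketched under the assumption that the images list and transform from the earlier cells are in scope, is to split the raw path list and wrap each half afterwards:

    # Split the file paths, then wrap each split in the Dataset class;
    # nothing is read from disk until a DataLoader requests an item.
    train_paths, test_paths = train_test_split(images, test_size=0.33, random_state=42)
    train_images = MyDataset(train_paths, transform)
    test_images = MyDataset(test_paths, transform)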
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": 16,
148
+ "id": "669f82ab",
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "m=len(train_images)"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 23,
158
+ "id": "e962953c",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "train_data, val_data = random_split(train_images, [int(m-m*0.2), int(m*0.2)])\n",
163
+ "test_dataset = test_images"
164
+ ]
165
+ },
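One caveat on the split lengths: random_split expects the given lengths to sum to len(train_images), and int(m - m*0.2) + int(m*0.2) can fall one element short whenever m*0.2 is not a whole number (for m = 13 it gives 10 + 2 = 12). A sketch that keeps the sum exact:

    val_len = int(m * 0.2)
    train_data, val_data = random_split(train_images, [m - val_len, val_len])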
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 24,
169
+ "id": "16a8e2a1",
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)\n",
174
+ "valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)\n",
175
+ "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,shuffle=True)"
176
+ ]
177
+ },
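batch_size is not defined in any cell of the committed notebook (it presumably still lived in the running kernel), so a fresh run would stop here with a NameError. A sketch with a hypothetical value:

    batch_size = 32  # hypothetical; choose to fit the GPU
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
    valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)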
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": 25,
181
+ "id": "07403239",
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "class Encoder(nn.Module):\n",
186
+ " \n",
187
+ " def __init__(self, encoded_space_dim,fc2_input_dim):\n",
188
+ " super().__init__()\n",
189
+ " \n",
190
+ " ### Convolutional section\n",
191
+ " self.encoder_cnn = nn.Sequential(\n",
192
+ " nn.Conv2d(1, 8, 3, stride=2, padding=1),\n",
193
+ " nn.ReLU(True),\n",
194
+ " nn.Conv2d(8, 16, 3, stride=2, padding=1),\n",
195
+ " nn.BatchNorm2d(16),\n",
196
+ " nn.ReLU(True),\n",
197
+ " nn.Conv2d(16, 32, 3, stride=2, padding=0),\n",
198
+ " nn.ReLU(True)\n",
199
+ " )\n",
200
+ " \n",
201
+ " ### Flatten layer\n",
202
+ " self.flatten = nn.Flatten(start_dim=1)\n",
203
+ "### Linear section\n",
204
+ " self.encoder_lin = nn.Sequential(\n",
205
+ " nn.Linear(3 * 3 * 32, 128),\n",
206
+ " nn.ReLU(True),\n",
207
+ " nn.Linear(128, encoded_space_dim)\n",
208
+ " )\n",
209
+ " \n",
210
+ " def forward(self, x):\n",
211
+ " x = self.encoder_cnn(x)\n",
212
+ " x = self.flatten(x)\n",
213
+ " x = self.encoder_lin(x)\n",
214
+ " return x\n",
215
+ "class Decoder(nn.Module):\n",
216
+ " \n",
217
+ " def __init__(self, encoded_space_dim,fc2_input_dim):\n",
218
+ " super().__init__()\n",
219
+ " self.decoder_lin = nn.Sequential(\n",
220
+ " nn.Linear(encoded_space_dim, 128),\n",
221
+ " nn.ReLU(True),\n",
222
+ " nn.Linear(128, 3 * 3 * 32),\n",
223
+ " nn.ReLU(True)\n",
224
+ " )\n",
225
+ "\n",
226
+ " self.unflatten = nn.Unflatten(dim=1, \n",
227
+ " unflattened_size=(32, 3, 3))\n",
228
+ "\n",
229
+ " self.decoder_conv = nn.Sequential(\n",
230
+ " nn.ConvTranspose2d(32, 16, 3, \n",
231
+ " stride=2, output_padding=0),\n",
232
+ " nn.BatchNorm2d(16),\n",
233
+ " nn.ReLU(True),\n",
234
+ " nn.ConvTranspose2d(16, 8, 3, stride=2, \n",
235
+ " padding=1, output_padding=1),\n",
236
+ " nn.BatchNorm2d(8),\n",
237
+ " nn.ReLU(True),\n",
238
+ " nn.ConvTranspose2d(8, 1, 3, stride=2, \n",
239
+ " padding=1, output_padding=1)\n",
240
+ " )\n",
241
+ " \n",
242
+ " def forward(self, x):\n",
243
+ " x = self.decoder_lin(x)\n",
244
+ " x = self.unflatten(x)\n",
245
+ " x = self.decoder_conv(x)\n",
246
+ " x = torch.sigmoid(x)\n",
247
+ " return x"
248
+ ]
249
+ },
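Both networks are taken from the linked MNIST tutorial: the first Conv2d expects a single input channel, and the 3 * 3 * 32 bottleneck assumes 28x28 inputs, so RGB images of another size would need the channel count changed and the flattened size recomputed. A sketch of finding that size empirically rather than hard-coding it, assuming hypothetical 3-channel 64x64 inputs:

    # Push a dummy batch through the conv stack to discover the flattened width.
    conv = nn.Sequential(
        nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU(True),
        nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(True),
        nn.Conv2d(16, 32, 3, stride=2, padding=0), nn.ReLU(True),
    )
    with torch.no_grad():
        flat_dim = nn.Flatten(start_dim=1)(conv(torch.zeros(1, 3, 64, 64))).shape[1]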
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": 26,
253
+ "id": "fedfd708",
254
+ "metadata": {},
255
+ "outputs": [
256
+ {
257
+ "name": "stdout",
258
+ "output_type": "stream",
259
+ "text": [
260
+ "Selected device: cuda\n"
261
+ ]
262
+ },
263
+ {
264
+ "data": {
265
+ "text/plain": [
266
+ "Decoder(\n",
267
+ " (decoder_lin): Sequential(\n",
268
+ " (0): Linear(in_features=4, out_features=128, bias=True)\n",
269
+ " (1): ReLU(inplace=True)\n",
270
+ " (2): Linear(in_features=128, out_features=288, bias=True)\n",
271
+ " (3): ReLU(inplace=True)\n",
272
+ " )\n",
273
+ " (unflatten): Unflatten(dim=1, unflattened_size=(32, 3, 3))\n",
274
+ " (decoder_conv): Sequential(\n",
275
+ " (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))\n",
276
+ " (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
277
+ " (2): ReLU(inplace=True)\n",
278
+ " (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n",
279
+ " (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
280
+ " (5): ReLU(inplace=True)\n",
281
+ " (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))\n",
282
+ " )\n",
283
+ ")"
284
+ ]
285
+ },
286
+ "execution_count": 26,
287
+ "metadata": {},
288
+ "output_type": "execute_result"
289
+ }
290
+ ],
291
+ "source": [
292
+ "### Define the loss function\n",
293
+ "loss_fn = torch.nn.MSELoss()\n",
294
+ "\n",
295
+ "### Define an optimizer (both for the encoder and the decoder!)\n",
296
+ "lr= 0.001\n",
297
+ "\n",
298
+ "### Set the random seed for reproducible results\n",
299
+ "torch.manual_seed(0)\n",
300
+ "\n",
301
+ "### Initialize the two networks\n",
302
+ "d = 4\n",
303
+ "\n",
304
+ "#model = Autoencoder(encoded_space_dim=encoded_space_dim)\n",
305
+ "encoder = Encoder(encoded_space_dim=d,fc2_input_dim=128)\n",
306
+ "decoder = Decoder(encoded_space_dim=d,fc2_input_dim=128)\n",
307
+ "params_to_optimize = [\n",
308
+ " {'params': encoder.parameters()},\n",
309
+ " {'params': decoder.parameters()}\n",
310
+ "]\n",
311
+ "\n",
312
+ "optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-05)\n",
313
+ "\n",
314
+ "# Check if the GPU is available\n",
315
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
316
+ "print(f'Selected device: {device}')\n",
317
+ "\n",
318
+ "# Move both the encoder and the decoder to the selected device\n",
319
+ "encoder.to(device)\n",
320
+ "decoder.to(device)"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": 33,
326
+ "id": "bae32de2",
327
+ "metadata": {},
328
+ "outputs": [],
329
+ "source": [
330
+ "### Training function\n",
331
+ "def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):\n",
332
+ " # Set train mode for both the encoder and the decoder\n",
333
+ " encoder.train()\n",
334
+ " decoder.train()\n",
335
+ " train_loss = []\n",
336
+ " # Iterate the dataloader (we do not need the label values, this is unsupervised learning)\n",
337
+ " for image_batch, _ in dataloader: # with \"_\" we just ignore the labels (the second element of the dataloader tuple)\n",
338
+ " # Move tensor to the proper device\n",
339
+ " image_batch = image_batch.to(device)\n",
340
+ " # Encode data\n",
341
+ " encoded_data = encoder(image_batch)\n",
342
+ " # Decode data\n",
343
+ " decoded_data = decoder(encoded_data)\n",
344
+ " # Evaluate loss\n",
345
+ " loss = loss_fn(decoded_data, image_batch)\n",
346
+ " # Backward pass\n",
347
+ " optimizer.zero_grad()\n",
348
+ " loss.backward()\n",
349
+ " optimizer.step()\n",
350
+ " # Print batch loss\n",
351
+ " print('\\t partial train loss (single batch): %f' % (loss.data))\n",
352
+ " train_loss.append(loss.detach().cpu().numpy())\n",
353
+ "\n",
354
+ " return np.mean(train_loss)"
355
+ ]
356
+ },
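The loop header "for image_batch, _ in dataloader:" mirrors the MNIST tutorial, where every item is an (image, label) pair; MyDataset returns a bare image, so each batch is a single tensor and the unpacking is what produces the "too many values to unpack" error in the last cell. A sketch of the inner loop for this unlabeled data:

    for image_batch in dataloader:          # each batch is just a tensor of images
        image_batch = image_batch.to(device)
        decoded_data = decoder(encoder(image_batch))
        loss = loss_fn(decoded_data, image_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.detach().cpu().numpy())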
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": 28,
360
+ "id": "ff2ec5fd",
361
+ "metadata": {},
362
+ "outputs": [],
363
+ "source": [
364
+ "### Testing function\n",
365
+ "def test_epoch(encoder, decoder, device, dataloader, loss_fn):\n",
366
+ " # Set evaluation mode for encoder and decoder\n",
367
+ " encoder.eval()\n",
368
+ " decoder.eval()\n",
369
+ " with torch.no_grad(): # No need to track the gradients\n",
370
+ " # Define the lists to store the outputs for each batch\n",
371
+ " conc_out = []\n",
372
+ " conc_label = []\n",
373
+ " for image_batch, _ in dataloader:\n",
374
+ " # Move tensor to the proper device\n",
375
+ " image_batch = image_batch.to(device)\n",
376
+ " # Encode data\n",
377
+ " encoded_data = encoder(image_batch)\n",
378
+ " # Decode data\n",
379
+ " decoded_data = decoder(encoded_data)\n",
380
+ " # Append the network output and the original image to the lists\n",
381
+ " conc_out.append(decoded_data.cpu())\n",
382
+ " conc_label.append(image_batch.cpu())\n",
383
+ " # Create a single tensor with all the values in the lists\n",
384
+ " conc_out = torch.cat(conc_out)\n",
385
+ " conc_label = torch.cat(conc_label) \n",
386
+ " # Evaluate global loss\n",
387
+ " val_loss = loss_fn(conc_out, conc_label)\n",
388
+ " return val_loss.data"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "code",
393
+ "execution_count": 29,
394
+ "id": "592ab5f1",
395
+ "metadata": {},
396
+ "outputs": [],
397
+ "source": [
398
+ "def plot_ae_outputs(encoder,decoder,n=10):\n",
399
+ " plt.figure(figsize=(16,4.5))\n",
400
+ " targets = test_dataset.targets.numpy()\n",
401
+ " t_idx = {i:np.where(targets==i)[0][0] for i in range(n)}\n",
402
+ " for i in range(n):\n",
403
+ " ax = plt.subplot(2,n,i+1)\n",
404
+ " img = test_dataset[t_idx[i]][0].unsqueeze(0).to(device)\n",
405
+ " encoder.eval()\n",
406
+ " decoder.eval()\n",
407
+ " with torch.no_grad():\n",
408
+ " rec_img = decoder(encoder(img))\n",
409
+ " plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')\n",
410
+ " ax.get_xaxis().set_visible(False)\n",
411
+ " ax.get_yaxis().set_visible(False) \n",
412
+ " if i == n//2:\n",
413
+ " ax.set_title('Original images')\n",
414
+ " ax = plt.subplot(2, n, i + 1 + n)\n",
415
+ " plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray') \n",
416
+ " ax.get_xaxis().set_visible(False)\n",
417
+ " ax.get_yaxis().set_visible(False) \n",
418
+ " if i == n//2:\n",
419
+ " ax.set_title('Reconstructed images')\n",
420
+ " plt.show() "
421
+ ]
422
+ },
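A further catch for when the loop runs: test_dataset.targets exists on torchvision's MNIST dataset used in the tutorial, not on the MyDataset wrapper here, and test_dataset[i] returns an image rather than an (image, label) tuple, so the digit-lookup logic has nothing to index. A sketch that simply plots the first n test images instead:

    # Index the unlabeled dataset directly; there are no class targets to look up.
    for i in range(n):
        img = test_dataset[i].unsqueeze(0).to(device)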
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": 34,
426
+ "id": "5f8b646b",
427
+ "metadata": {},
428
+ "outputs": [
429
+ {
430
+ "ename": "ValueError",
431
+ "evalue": "too many values to unpack (expected 2)",
432
+ "output_type": "error",
433
+ "traceback": [
434
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
435
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
436
+ "Input \u001b[0;32mIn [34]\u001b[0m, in \u001b[0;36m<cell line: 3>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m diz_loss \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_loss\u001b[39m\u001b[38;5;124m'\u001b[39m:[],\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m:[]}\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n\u001b[0;32m----> 4\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mencoder\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdecoder\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43mloss_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43moptim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m val_loss \u001b[38;5;241m=\u001b[39m test_epoch(encoder,decoder,device,test_loader,loss_fn)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m EPOCH \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m train loss \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m val loss \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mformat(epoch \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m, num_epochs,train_loss,val_loss))\n",
437
+ "Input \u001b[0;32mIn [33]\u001b[0m, in \u001b[0;36mtrain_epoch\u001b[0;34m(encoder, decoder, device, dataloader, loss_fn, optimizer)\u001b[0m\n\u001b[1;32m 6\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Iterate the dataloader (we do not need the label values, this is unsupervised learning)\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m image_batch, _ \u001b[38;5;129;01min\u001b[39;00m dataloader: \u001b[38;5;66;03m# with \"_\" we just ignore the labels (the second element of the dataloader tuple)\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m# Move tensor to the proper device\u001b[39;00m\n\u001b[1;32m 10\u001b[0m image_batch \u001b[38;5;241m=\u001b[39m image_batch\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# Encode data\u001b[39;00m\n",
438
+ "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)"
439
+ ]
440
+ }
441
+ ],
442
+ "source": [
443
+ "num_epochs = 30\n",
444
+ "diz_loss = {'train_loss':[],'val_loss':[]}\n",
445
+ "for epoch in range(num_epochs):\n",
446
+ " train_loss =train_epoch(encoder,decoder,device,train_loader,loss_fn,optim)\n",
447
+ " val_loss = test_epoch(encoder,decoder,device,test_loader,loss_fn)\n",
448
+ " print('\\n EPOCH {}/{} \\t train loss {} \\t val loss {}'.format(epoch + 1, num_epochs,train_loss,val_loss))\n",
449
+ " diz_loss['train_loss'].append(train_loss)\n",
450
+ " diz_loss['val_loss'].append(val_loss)\n",
451
+ " plot_ae_outputs(encoder,decoder,n=10)"
452
+ ]
453
+ }
454
+ ],
455
+ "metadata": {
456
+ "kernelspec": {
457
+ "display_name": "Python 3 (ipykernel)",
458
+ "language": "python",
459
+ "name": "python3"
460
+ },
461
+ "language_info": {
462
+ "codemirror_mode": {
463
+ "name": "ipython",
464
+ "version": 3
465
+ },
466
+ "file_extension": ".py",
467
+ "mimetype": "text/x-python",
468
+ "name": "python",
469
+ "nbconvert_exporter": "python",
470
+ "pygments_lexer": "ipython3",
471
+ "version": "3.9.12"
472
+ }
473
+ },
474
+ "nbformat": 4,
475
+ "nbformat_minor": 5
476
+ }