abbylagar committed on
Commit
5b6d92b
1 Parent(s): 2926dae

Upload denoiseCIFAR100.py

Files changed (1)
denoiseCIFAR100.py +110 -0
denoiseCIFAR100.py ADDED
@@ -0,0 +1,110 @@
+ """
+ PyTorch Lightning module for denoising CIFAR100.
+ """
+
+
+ # imports
+
+ import numpy as np
+ from torch import nn
+ import torch
+ import torchvision
+ from einops import rearrange, reduce
+ from argparse import ArgumentParser
+ from pytorch_lightning import LightningModule, Trainer, Callback
+ from pytorch_lightning.loggers import WandbLogger
+ from torch.optim import Adam
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ from encoder import Encoder
+ from decoder import Decoder
+
+
+ class DenoiseCIFAR100Model(LightningModule):
+
+     def __init__(self, feature_dim=256, lr=0.001, batch_size=64,
+                  num_workers=2, max_epochs=30, **kwargs):
+         super().__init__()
+         self.save_hyperparameters()
+         self.encoder = Encoder(feature_dim=feature_dim)
+         self.decoder = Decoder(feature_dim=feature_dim)
+         self.loss = nn.MSELoss()
+
+     def forward(self, x):
+         h = self.encoder(x)
+         x_tilde = self.decoder(h)
+         return x_tilde
+
+     # this is called for every training batch during fit()
+     def training_step(self, batch, batch_idx):
+         x_in, x = batch
+         x_tilde = self.forward(x_in)
+         loss = self.loss(x_tilde, x)
+         return {"loss": loss}
+
+     # this is called at the end of a training epoch; calls to self.log() are recorded in wandb
+     def training_epoch_end(self, outputs):
+         avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+         self.log("train_loss", avg_loss, on_epoch=True)
+
+     # this is called for every test batch
+     def test_step(self, batch, batch_idx):
+         x_in, x = batch
+         x_tilde = self.forward(x_in)
+         loss = self.loss(x_tilde, x)
+         return {"x_in": x_in, "x": x, "x_tilde": x_tilde, "test_loss": loss}
+
+     # this is called at the end of the test epoch
+     def test_epoch_end(self, outputs):
+         avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
+         self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
+
+     # validation is the same as test
+     def validation_step(self, batch, batch_idx):
+         return self.test_step(batch, batch_idx)
+
+     def validation_epoch_end(self, outputs):
+         return self.test_epoch_end(outputs)
+
+     # we use the Adam optimizer
+     def configure_optimizers(self):
+         optimizer = Adam(self.parameters(), lr=self.hparams.lr)
+         # this decays the learning rate to 0 after max_epochs using cosine annealing
+         scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
+         return [optimizer], [scheduler]
+
+     # this is called after model instantiation to initialize the datasets and dataloaders
+     def setup(self, stage=None):
+         self.train_dataloader()
+         self.test_dataloader()
+
+     # build train and test dataloaders using the CIFAR100 dataset
+     # we use a simple ToTensor transform; noise_collate_fn (not defined in this file) pairs noisy inputs with clean targets
+     def train_dataloader(self):
+         return torch.utils.data.DataLoader(
+             torchvision.datasets.CIFAR100(
+                 "./data", train=True, download=True,
+                 transform=torchvision.transforms.ToTensor()
+             ),
+             batch_size=self.hparams.batch_size,
+             shuffle=True,
+             num_workers=self.hparams.num_workers,
+             pin_memory=True,
+             collate_fn=noise_collate_fn
+         )
+
+     def test_dataloader(self):
+         return torch.utils.data.DataLoader(
+             torchvision.datasets.CIFAR100(
+                 "./data", train=False, download=True,
+                 transform=torchvision.transforms.ToTensor()
+             ),
+             batch_size=self.hparams.batch_size,
+             shuffle=False,
+             num_workers=self.hparams.num_workers,
+             pin_memory=True,
+             collate_fn=noise_collate_fn
+         )
+
+
+     def val_dataloader(self):
+         return self.test_dataloader()
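Both dataloaders pass collate_fn=noise_collate_fn, but that function is not defined or imported in this file. A minimal sketch of what such a collate function might look like, assuming (image, label) samples from torchvision.datasets.CIFAR100 with ToTensor applied and additive Gaussian noise on the inputs; the actual noise model and the location of this function in the repo may differ:

import torch

def noise_collate_fn(batch):
    # batch is a list of (image, label) tuples from CIFAR100 with ToTensor applied.
    # Hypothetical reconstruction: stack the clean images, add Gaussian noise,
    # and return (noisy_input, clean_target) as expected by training_step/test_step.
    x = torch.stack([img for img, _ in batch])   # (B, 3, 32, 32), values in [0, 1]
    noise = torch.randn_like(x) * 0.1            # assumed noise level sigma = 0.1
    x_in = torch.clamp(x + noise, 0.0, 1.0)      # keep noisy inputs in the valid pixel range
    return x_in, x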
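The module imports Trainer and WandbLogger but does not use them in this file, so training is presumably driven from a separate script. A hedged usage sketch under that assumption (the import path denoiseCIFAR100, the wandb project name, and the accelerator settings are placeholders, not part of the repo):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from denoiseCIFAR100 import DenoiseCIFAR100Model

model = DenoiseCIFAR100Model(feature_dim=256, lr=0.001, batch_size=64, max_epochs=30)
trainer = Trainer(
    max_epochs=model.hparams.max_epochs,
    accelerator="gpu",                               # assumes a single GPU is available
    devices=1,
    logger=WandbLogger(project="denoise-cifar100"),  # hypothetical wandb project name
)
trainer.fit(model)   # uses the dataloaders defined on the module itself
trainer.test(model)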