""" pytorch lightning module for denoising CIFAR100 """ #import functions import numpy as np from torch import nn import torch import torchvision from einops import rearrange, reduce from argparse import ArgumentParser from pytorch_lightning import LightningModule, Trainer, Callback from pytorch_lightning.loggers import WandbLogger from torch.optim import Adam from torch.optim.lr_scheduler import CosineAnnealingLR from encoder import Encoder from decoder import Decoder class DenoiseCIFAR100Model(LightningModule): def __init__(self, feature_dim=256, lr=0.001, batch_size=64, num_workers=2, max_epochs=30, **kwargs): super().__init__() self.save_hyperparameters() self.encoder = Encoder(feature_dim=feature_dim) self.decoder = Decoder(feature_dim=feature_dim) self.loss = nn.MSELoss() def forward(self, x): h = self.encoder(x) x_tilde = self.decoder(h) return x_tilde # this is called during fit() def training_step(self, batch, batch_idx): x_in, x = batch x_tilde = self.forward(x_in) loss = self.loss(x_tilde, x) return {"loss": loss} # calls to self.log() are recorded in wandb def training_epoch_end(self, outputs): avg_loss = torch.stack([x["loss"] for x in outputs]).mean() self.log("train_loss", avg_loss, on_epoch=True) # this is called at the end of an epoch def test_step(self, batch, batch_idx): x_in, x = batch x_tilde = self.forward(x_in) loss = self.loss(x_tilde, x) return {"x_in" : x_in, "x": x, "x_tilde" : x_tilde, "test_loss" : loss,} # this is called at the end of all epochs def test_epoch_end(self, outputs): avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean() self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True) # validation is the same as test def validation_step(self, batch, batch_idx): return self.test_step(batch, batch_idx) def validation_epoch_end(self, outputs): return self.test_epoch_end(outputs) # we use Adam optimizer def configure_optimizers(self): optimizer = Adam(self.parameters(), lr=self.hparams.lr) # this decays the learning rate to 0 after max_epochs using cosine annealing scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs) return [optimizer], [scheduler], # this is called after model instatiation to initiliaze the datasets and dataloaders def setup(self, stage=None): self.train_dataloader() self.test_dataloader() # build train and test dataloaders using MNIST dataset # we use simple ToTensor transform def train_dataloader(self): return torch.utils.data.DataLoader( torchvision.datasets.CIFAR100( "./data", train=True, download=True, transform=torchvision.transforms.ToTensor() ), batch_size=self.hparams.batch_size, shuffle=True, num_workers=self.hparams.num_workers, pin_memory=True, collate_fn=noise_collate_fn ) def test_dataloader(self): return torch.utils.data.DataLoader( torchvision.datasets.CIFAR100( "./data", train=False, download=True, transform=torchvision.transforms.ToTensor() ), batch_size=self.hparams.batch_size, shuffle=False, num_workers=self.hparams.num_workers, pin_memory=True, collate_fn=noise_collate_fn ) def val_dataloader(self): return self.test_dataloader()