Denoising_CIFAR100 / denoiseCIFAR100.py
"""
pytorch lightning module for denoising CIFAR100
"""
# imports
import numpy as np
from torch import nn
import torch
import torchvision
from einops import rearrange, reduce
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from encoder import Encoder
from decoder import Decoder
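
# NOTE: the dataloaders below pass `collate_fn=noise_collate_fn`, but its
# definition is not included in this dump. The version here is a minimal
# sketch under that assumption: it stacks the (image, label) pairs from
# CIFAR100, drops the labels, and returns a (noisy, clean) image pair.
# The additive Gaussian noise level (std=0.1) is an illustrative choice,
# not taken from the original source.
def noise_collate_fn(batch):
    # keep only the images; labels are irrelevant for denoising
    x = torch.stack([img for img, _ in batch])
    # corrupt the clean images with Gaussian noise and clamp back
    # to the valid [0, 1] range produced by ToTensor()
    x_in = (x + 0.1 * torch.randn_like(x)).clamp(0.0, 1.0)
    return x_in, x
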
class DenoiseCIFAR100Model(LightningModule):
    def __init__(self, feature_dim=256, lr=0.001, batch_size=64,
                 num_workers=2, max_epochs=30, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # autoencoder: the encoder compresses the image into a feature
        # vector and the decoder reconstructs a clean image from it
        self.encoder = Encoder(feature_dim=feature_dim)
        self.decoder = Decoder(feature_dim=feature_dim)
        self.loss = nn.MSELoss()

    def forward(self, x):
        # x: noisy input batch -> x_tilde: denoised reconstruction
        h = self.encoder(x)
        x_tilde = self.decoder(h)
        return x_tilde

    # this is called for every training batch during fit()
    def training_step(self, batch, batch_idx):
        # batch is a (noisy, clean) pair produced by noise_collate_fn
        x_in, x = batch
        x_tilde = self.forward(x_in)
        loss = self.loss(x_tilde, x)
        return {"loss": loss}

    # average the per-batch losses at the end of each training epoch;
    # calls to self.log() are recorded by the attached logger (e.g. wandb)
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    # this is called for every test batch
    def test_step(self, batch, batch_idx):
        x_in, x = batch
        x_tilde = self.forward(x_in)
        loss = self.loss(x_tilde, x)
        return {"x_in": x_in, "x": x, "x_tilde": x_tilde, "test_loss": loss}

    # this is called at the end of the test (or validation) epoch
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)

    # validation is the same as test
    def validation_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)

    # we use the Adam optimizer
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # cosine annealing decays the learning rate to 0 over max_epochs
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
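        # for reference, the annealed rate at epoch t is
        #   lr_t = lr_min + 0.5 * (lr - lr_min) * (1 + cos(pi * t / T_max))
        # with lr_min = 0 by default, so it falls smoothly from the
        # initial lr to 0 as t approaches max_epochs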
        return [optimizer], [scheduler]

    # called by the Trainer before fit/test to initialize the datasets
    # and dataloaders (and trigger the CIFAR100 download)
    def setup(self, stage=None):
        self.train_dataloader()
        self.test_dataloader()

    # build train and test dataloaders for the CIFAR100 dataset
    # using a simple ToTensor transform; noise_collate_fn turns
    # each batch into a (noisy, clean) pair
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR100(
                "./data", train=True, download=True,
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            collate_fn=noise_collate_fn
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR100(
                "./data", train=False, download=True,
                transform=torchvision.transforms.ToTensor()
            ),
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            collate_fn=noise_collate_fn
        )

    def val_dataloader(self):
        return self.test_dataloader()
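

# Minimal training sketch (not part of the original file): one plausible way
# to run this module with the Trainer imported above; a WandbLogger could be
# passed via Trainer(logger=...) to record the self.log() calls.
if __name__ == "__main__":
    model = DenoiseCIFAR100Model()
    trainer = Trainer(max_epochs=model.hparams.max_epochs)
    trainer.fit(model)
    trainer.test(model)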