Spaces:
Runtime error
Runtime error
"""
PyTorch Lightning module for denoising CIFAR-100 images.
"""
#import functions | |
import numpy as np | |
from torch import nn | |
import torch | |
import torchvision | |
from einops import rearrange, reduce | |
from argparse import ArgumentParser | |
from pytorch_lightning import LightningModule, Trainer, Callback | |
from pytorch_lightning.loggers import WandbLogger | |
from torch.optim import Adam | |
from torch.optim.lr_scheduler import CosineAnnealingLR | |
from encoder import Encoder | |
from decoder import Decoder | |
class DenoiseCIFAR100Model(LightningModule):
    """Encoder/decoder model trained to reconstruct CIFAR-100 images.

    Every batch is a pair ``(x_in, x)`` produced by ``noise_collate_fn``.
    The model is fed ``x_in`` and trained with mean-squared error to
    reproduce ``x``. ``x_in`` is presumably a noise-corrupted copy of ``x``;
    confirm this against ``noise_collate_fn``.

    Args:
        feature_dim: Latent size passed to both ``Encoder`` and ``Decoder``.
        lr: Adam learning rate.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of ``DataLoader`` worker processes.
        max_epochs: Period (``T_max``) of the cosine-annealing LR schedule.
        **kwargs: Extra hyperparameters. They are captured by
            ``save_hyperparameters`` but not otherwise used by this class.
    """

    def __init__(self, feature_dim=256, lr=0.001, batch_size=64,
                 num_workers=2, max_epochs=30, **kwargs):
        super().__init__()
        # Exposes the constructor arguments as self.hparams.*
        self.save_hyperparameters()
        self.encoder = Encoder(feature_dim=feature_dim)
        self.decoder = Decoder(feature_dim=feature_dim)
        self.loss = nn.MSELoss()

    def forward(self, x):
        """Encode ``x`` and decode it back; return the reconstruction."""
        h = self.encoder(x)
        x_tilde = self.decoder(h)
        return x_tilde

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one training batch.

        The loss is logged as ``train_loss`` and averaged over the epoch by
        Lightning (``on_epoch=True``). This replaces a manual
        ``training_epoch_end`` hook, which Lightning 2.0 no longer supports.
        """
        x_in, x = batch
        x_tilde = self.forward(x_in)
        loss = self.loss(x_tilde, x)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        """Compute the reconstruction loss for one test batch.

        ``test_loss`` is logged per epoch (mean over batches) and shown in
        the progress bar. The inputs, targets and reconstructions are
        returned for any callback that consumes step outputs. No
        ``*_epoch_end`` hook is defined, so Lightning does not accumulate
        these tensors for the whole epoch.
        """
        x_in, x = batch
        x_tilde = self.forward(x_in)
        loss = self.loss(x_tilde, x)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return {"x_in": x_in, "x": x, "x_tilde": x_tilde, "test_loss": loss}

    def validation_step(self, batch, batch_idx):
        """Validation is identical to testing, including the ``test_loss`` key."""
        return self.test_step(batch, batch_idx)

    def configure_optimizers(self):
        """Use Adam with a cosine-annealed learning rate.

        With ``T_max = max_epochs``, the learning rate decays from ``lr`` to 0
        (the ``CosineAnnealingLR`` default ``eta_min``) over ``max_epochs``
        scheduler steps.
        """
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]

    def prepare_data(self):
        """Download both CIFAR-100 splits into ``./data`` if missing.

        Lightning runs this hook before ``setup``. It is the recommended
        place for downloads, so concurrent processes do not all fetch the
        data at the same time.
        """
        for train in (True, False):
            torchvision.datasets.CIFAR100("./data", train=train, download=True)

    def setup(self, stage=None):
        """Lightning setup hook; nothing to do here.

        Data is downloaded in ``prepare_data`` and datasets are built on
        demand by the ``*_dataloader`` methods. The hook is kept so existing
        callers of ``setup()`` keep working.
        """

    def _make_dataloader(self, train):
        """Build a CIFAR-100 DataLoader for the train (``True``) or test split."""
        dataset = torchvision.datasets.CIFAR100(
            "./data", train=train, download=True,
            transform=torchvision.transforms.ToTensor(),
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=train,  # shuffle the training split only
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            # NOTE(review): noise_collate_fn is neither defined nor imported in
            # the lines reviewed here; confirm it is provided at module level.
            collate_fn=noise_collate_fn,
        )

    def train_dataloader(self):
        """Return the shuffled CIFAR-100 training DataLoader."""
        return self._make_dataloader(train=True)

    def test_dataloader(self):
        """Return the (unshuffled) CIFAR-100 test DataLoader."""
        return self._make_dataloader(train=False)

    def val_dataloader(self):
        """Validation uses the CIFAR-100 test split."""
        return self.test_dataloader()