Spaces:
Runtime error
Runtime error
File size: 3,719 Bytes
5b6d92b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
"""
PyTorch Lightning module for denoising CIFAR-100 images.
"""
#import functions
import numpy as np
from torch import nn
import torch
import torchvision
from einops import rearrange, reduce
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from encoder import Encoder
from decoder import Decoder
class DenoiseCIFAR100Model(LightningModule):
    """Encoder/decoder model trained to reconstruct CIFAR-100 images.

    Every batch is unpacked as ``(x_in, x)``. ``x_in`` is passed through the
    encoder and decoder, and the reconstruction is scored against ``x`` with
    mean-squared error. The pairs are assembled by ``noise_collate_fn``.
    Presumably ``x_in`` is a noise-corrupted copy of the clean image ``x``;
    confirm against that function.

    NOTE(review): the ``*_epoch_end`` hooks below were removed in PyTorch
    Lightning 2.0. This class therefore needs a Lightning 1.x install;
    confirm the pinned version.

    Args:
        feature_dim: size passed to both ``Encoder`` and ``Decoder``.
        lr: learning rate for the Adam optimizer.
        batch_size: DataLoader batch size, used for every split.
        num_workers: number of DataLoader worker processes, used for every
            split.
        max_epochs: ``T_max`` of the cosine-annealing learning-rate schedule.
        **kwargs: accepted so extra options can be passed in; not read
            directly by this class.
    """

    def __init__(self, feature_dim=256, lr=0.001, batch_size=64,
                 num_workers=2, max_epochs=30, **kwargs):
        super().__init__()
        # Stores the constructor arguments on self.hparams. The optimizer,
        # scheduler and dataloader methods below read them from there.
        self.save_hyperparameters()
        self.encoder = Encoder(feature_dim=feature_dim)
        self.decoder = Decoder(feature_dim=feature_dim)
        self.loss = nn.MSELoss()

    def forward(self, x):
        """Encode ``x``, decode the features, and return the reconstruction."""
        h = self.encoder(x)
        x_tilde = self.decoder(h)
        return x_tilde

    def training_step(self, batch, batch_idx):
        """Return the MSE reconstruction loss for one training batch.

        Called by ``Trainer.fit()`` once per batch.
        """
        x_in, x = batch
        # Calling the module (rather than .forward) goes through
        # nn.Module.__call__, so any registered hooks also run.
        x_tilde = self(x_in)
        loss = self.loss(x_tilde, x)
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        """Log the mean of the per-batch training losses as ``train_loss``."""
        # Values passed to self.log() are forwarded to the configured logger
        # (for example, wandb).
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("train_loss", avg_loss, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Reconstruct one test batch and return its inputs, outputs and loss.

        Called once per test batch. ``validation_step`` also delegates here.

        NOTE(review): the full image tensors (``x_in``, ``x``, ``x_tilde``)
        are returned, so they are collected for the epoch-end hook across
        the whole epoch. That memory grows with the dataset size, yet only
        ``test_loss`` is read by ``test_epoch_end``. Confirm whether anything
        else (e.g. a callback) consumes the images before trimming this.
        """
        x_in, x = batch
        x_tilde = self(x_in)
        loss = self.loss(x_tilde, x)
        return {"x_in": x_in, "x": x, "x_tilde": x_tilde, "test_loss": loss}

    def test_epoch_end(self, outputs):
        """Average the per-batch ``test_loss`` values and log them (shown on the progress bar)."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Validation reuses the test step unchanged."""
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        """Validation reuses the test epoch-end aggregation.

        NOTE(review): this means the validation loss is logged under the key
        ``test_loss``, not ``val_loss``. Confirm that any checkpoint or
        early-stopping monitor expects that key.
        """
        return self.test_epoch_end(outputs)

    def configure_optimizers(self):
        """Return Adam plus a cosine-annealing schedule over ``max_epochs``."""
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # Decays the learning rate along a cosine curve towards 0 (the
        # scheduler's default eta_min) over max_epochs scheduler steps.
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]

    def setup(self, stage=None):
        """Lightning hook: build the train and test DataLoaders up front.

        Building them constructs the CIFAR-100 datasets with
        ``download=True``, so the data is fetched into ``./data`` here. The
        loaders themselves are discarded.

        NOTE(review): Lightning runs ``setup`` on every process in
        distributed training. The ``prepare_data`` hook is the documented
        place for downloads; consider moving it there.
        """
        self.train_dataloader()
        self.test_dataloader()

    def _build_cifar100_loader(self, train):
        """Build a CIFAR-100 DataLoader for the training (``train=True``) or test split.

        Images are converted with ``ToTensor``. Only the training split is
        shuffled. Batches are assembled by ``noise_collate_fn``.
        """
        dataset = torchvision.datasets.CIFAR100(
            "./data",
            train=train,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )
        # NOTE(review): noise_collate_fn is not defined or imported anywhere
        # in the visible source. Unless it is provided elsewhere, this line
        # raises NameError; confirm where it comes from.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=train,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            collate_fn=noise_collate_fn,
        )

    def train_dataloader(self):
        """Return the shuffled CIFAR-100 training DataLoader."""
        return self._build_cifar100_loader(train=True)

    def test_dataloader(self):
        """Return the unshuffled CIFAR-100 test DataLoader."""
        return self._build_cifar100_loader(train=False)

    def val_dataloader(self):
        """Validation uses the CIFAR-100 test split."""
        return self.test_dataloader()