""" https://github.com/marrrcin/pytorch-resnet-mnist/blob/master/pytorch-resnet-mnist.ipynb https://github.com/huyvnphan/PyTorch_CIFAR10/tree/master/cifar10_models """ import os, sys from argparse import ArgumentParser from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint # from pytorch_lightning.accelerators import GPUAccelerator sys.path.append('/home/yiming/ContrastDebugger/Model-mnist') from data import MNISTData # from module import MNISTModule import pytorch_lightning as pl import torch # from torchmetrics import Accuracy from pytorch_lightning.metrics import Accuracy from cifar10_models.densenet import densenet121, densenet161, densenet169 from cifar10_models.googlenet import googlenet from cifar10_models.inception import inception_v3 from cifar10_models.mobilenetv2 import mobilenet_v2 from cifar10_models.resnet import resnet18, resnet34, resnet50 from cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn from cifar10_models.mlp import mlp3 from cifar10_models.convnet import convnet from schduler import WarmupCosineLR from torch import nn # from pytorch_lightning.core.decorators import auto_move_data from torchvision.transforms import ToTensor from torchvision.datasets import MNIST from torch.utils.data import DataLoader import json import os parser = ArgumentParser() # PROGRAM level args parser.add_argument("--data_dir", type=str, default="data") parser.add_argument("--test_phase", type=int, default=0, choices=[0, 1]) parser.add_argument("--dev", type=int, default=0, choices=[0, 1]) # TRAINER args parser.add_argument("--classifier", type=str, default="resnet18") parser.add_argument("--precision", type=int, default=32, choices=[16, 32]) parser.add_argument("--batch_size", type=int, default=128) parser.add_argument("--max_epochs", type=int, default=20) parser.add_argument("--num_workers", type=int, default=2) parser.add_argument("--gpu_id", type=str, default="0") parser.add_argument("--learning_rate", type=float, default=5e-3) parser.add_argument("--weight_decay", type=float, default=1e-2) parser.add_argument("--filepath", type=str, default="Model") parser.add_argument("--period", type=int, default=1) parser.add_argument("--save_top_k", type=int, default=-1) args = parser.parse_args(args=[]) def main(args): seed_everything(0) os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id checkpoint = ModelCheckpoint( # dirpath=os.path.join(args.filepath, args.classifier), # filename="{epoch:03d}", filepath=os.path.join(args.filepath, args.classifier, "{epoch:03d}"), monitor="acc/val", mode="max", # save_last=False, period=args.period, save_top_k=args.save_top_k, save_weights_only=True, ) trainer = Trainer( fast_dev_run=bool(args.dev), gpus=args.gpu_id, deterministic=True, weights_summary=None, log_every_n_steps=1, max_epochs=args.max_epochs, checkpoint_callback=checkpoint, precision=args.precision, ) model = MNISTModule(args) # data = MNISTData(args) # trainloader = data.train_dataloader() # data.save_train_data(trainloader, args.filepath) # testloader = data.test_dataloader() # data.save_test_data(testloader, args.filepath) # if bool(args.test_phase): # trainer.test(model, data.test_dataloader()) # else: # trainer.fit(model, data) # trainer.test() all_classifiers = { "vgg11_bn": vgg11_bn(), "vgg13_bn": vgg13_bn(), "vgg16_bn": vgg16_bn(), "vgg19_bn": vgg19_bn(), "resnet18": resnet18(), "resnet34": resnet34(), "resnet50": resnet50(), "densenet121": densenet121(), "densenet161": densenet161(), "densenet169": densenet169(), "mobilenet_v2": mobilenet_v2(), "googlenet": googlenet(), "inception_v3": inception_v3(), "mlp":mlp3(), "convnet":convnet() } class MNISTModule(pl.LightningModule): def __init__(self, my_hparams): super().__init__() self.my_hparams = my_hparams self.criterion = torch.nn.CrossEntropyLoss() self.accuracy = Accuracy() self.model = all_classifiers[self.my_hparams.classifier] if self.my_hparams.classifier not in ["mlp", "convnet"]: self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) def forward(self, batch): images, labels = batch predictions = self.model(images) loss = self.criterion(predictions, labels) accuracy = self.accuracy(predictions, labels) return loss, accuracy * 100 # @auto_move_data # def forward(self, x): # return self.model(x) def training_step(self, batch, batch_nb): loss, accuracy = self.forward(batch) self.log("loss/train", loss) self.log("acc/train", accuracy) return loss def validation_step(self, batch, batch_nb): loss, accuracy = self.forward(batch) self.log("loss/val", loss) self.log("acc/val", accuracy) def test_step(self, batch, batch_nb): loss, accuracy = self.forward(batch) self.log("acc/test", accuracy) def train_dataloader(self): transform = ToTensor() dataset = MNIST("mnist", train=True, download=True, transform=transform) dataloader = DataLoader( dataset, batch_size=self.my_hparams.batch_size, num_workers=self.my_hparams.num_workers, shuffle=True, ) return dataloader def val_dataloader(self): transform = ToTensor() dataset = MNIST("mnist", train=False, download=True, transform=transform) dataloader = DataLoader( dataset, batch_size=self.my_hparams.batch_size, num_workers=self.my_hparams.num_workers, drop_last=True, pin_memory=True, ) return dataloader def test_dataloader(self): return self.val_dataloader() def configure_optimizers(self): optimizer = torch.optim.SGD( self.model.parameters(), lr=self.my_hparams.learning_rate, weight_decay=self.my_hparams.weight_decay, momentum=0.9, nesterov=True, ) total_steps = self.my_hparams.max_epochs * len(self.train_dataloader()) scheduler = { "scheduler": WarmupCosineLR( optimizer, warmup_epochs=total_steps * 0.3, max_epochs=total_steps ), "interval": "step", "name": "learning_rate", } return [optimizer], [scheduler] def on_train_epoch_end(self, epoch_output): epoch = self.trainer.current_epoch state_dict = self.model.state_dict() save_dir = "/home/yiming/EXP/mnist_resnet18/Model/Epoch_" + str(self.current_epoch + 1) os.makedirs(save_dir, exist_ok=True) save_path = os.path.join(save_dir, "subject_model.pth") torch.save(state_dict, save_path) main(args)