"""
https://github.com/marrrcin/pytorch-resnet-mnist/blob/master/pytorch-resnet-mnist.ipynb
https://github.com/huyvnphan/PyTorch_CIFAR10/tree/master/cifar10_models
"""

import os
import sys
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

sys.path.append('/home/yiming/ContrastDebugger/Model-mnist')
from data import MNISTData

import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics import Accuracy

from cifar10_models.densenet import densenet121, densenet161, densenet169
from cifar10_models.googlenet import googlenet
from cifar10_models.inception import inception_v3
from cifar10_models.mobilenetv2 import mobilenet_v2
from cifar10_models.resnet import resnet18, resnet34, resnet50
from cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from cifar10_models.mlp import mlp3
from cifar10_models.convnet import convnet
from schduler import WarmupCosineLR
from torch import nn

from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import json
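
# NOTE: `cifar10_models` and `schduler` follow the layout of the PyTorch_CIFAR10
# repo linked in the docstring above. `pytorch_lightning.metrics.Accuracy` and the
# ModelCheckpoint/Trainer arguments used below (filepath, period,
# checkpoint_callback, weights_summary) belong to an older pytorch-lightning API,
# so a pre-torchmetrics Lightning release is assumed here.
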
parser = ArgumentParser()

parser.add_argument("--data_dir", type=str, default="data")
parser.add_argument("--test_phase", type=int, default=0, choices=[0, 1])
parser.add_argument("--dev", type=int, default=0, choices=[0, 1])

parser.add_argument("--classifier", type=str, default="resnet18")
parser.add_argument("--precision", type=int, default=32, choices=[16, 32])
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--max_epochs", type=int, default=20)
parser.add_argument("--num_workers", type=int, default=2)
parser.add_argument("--gpu_id", type=str, default="0")

parser.add_argument("--learning_rate", type=float, default=5e-3)
parser.add_argument("--weight_decay", type=float, default=1e-2)
parser.add_argument("--filepath", type=str, default="Model")
parser.add_argument("--period", type=int, default=1)
parser.add_argument("--save_top_k", type=int, default=-1)

args = parser.parse_args()


def main(args):
    seed_everything(0)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    checkpoint = ModelCheckpoint(
        filepath=os.path.join(args.filepath, args.classifier, "{epoch:03d}"),
        monitor="acc/val",
        mode="max",
        period=args.period,
        save_top_k=args.save_top_k,
        save_weights_only=True,
    )

    trainer = Trainer(
        fast_dev_run=bool(args.dev),
        gpus=args.gpu_id,
        deterministic=True,
        weights_summary=None,
        log_every_n_steps=1,
        max_epochs=args.max_epochs,
        checkpoint_callback=checkpoint,
        precision=args.precision,
    )
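    # NOTE: CUDA_VISIBLE_DEVICES is already restricted to --gpu_id above; how
    # Lightning interprets a *string* value for `gpus` has changed between
    # releases, so gpus=-1 ("all visible devices") is a common alternative here.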

    model = MNISTModule(args)

    # Assumed training/eval dispatch based on --test_phase.
    if bool(args.test_phase):
        trainer.test(model)
    else:
        trainer.fit(model)


all_classifiers = {
    "vgg11_bn": vgg11_bn(),
    "vgg13_bn": vgg13_bn(),
    "vgg16_bn": vgg16_bn(),
    "vgg19_bn": vgg19_bn(),
    "resnet18": resnet18(),
    "resnet34": resnet34(),
    "resnet50": resnet50(),
    "densenet121": densenet121(),
    "densenet161": densenet161(),
    "densenet169": densenet169(),
    "mobilenet_v2": mobilenet_v2(),
    "googlenet": googlenet(),
    "inception_v3": inception_v3(),
    "mlp": mlp3(),
    "convnet": convnet(),
}
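# NOTE: every architecture above is instantiated eagerly at import time even
# though only the entry selected by --classifier is ever used, so this table
# costs memory and startup time proportional to all of the models.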


class MNISTModule(pl.LightningModule):
    def __init__(self, my_hparams):
        super().__init__()
        self.my_hparams = my_hparams

        self.criterion = torch.nn.CrossEntropyLoss()
        self.accuracy = Accuracy()

        self.model = all_classifiers[self.my_hparams.classifier]
        if self.my_hparams.classifier not in ["mlp", "convnet"]:
            self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
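        # NOTE: the conv1 swap above adapts a 3-channel CIFAR-style stem to the
        # 1-channel MNIST input; it assumes the chosen backbone exposes a `conv1`
        # attribute (true for the ResNet variants, whereas e.g. the VGG and
        # DenseNet models organise their first convolution differently).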

    def forward(self, batch):
        images, labels = batch
        predictions = self.model(images)
        loss = self.criterion(predictions, labels)
        accuracy = self.accuracy(predictions, labels)
        return loss, accuracy * 100
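
    # NOTE: forward() returns (loss, accuracy in %) rather than raw logits; the
    # training/validation/test steps below reuse it for both metrics.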

    def training_step(self, batch, batch_nb):
        loss, accuracy = self.forward(batch)
        self.log("loss/train", loss)
        self.log("acc/train", accuracy)
        return loss

    def validation_step(self, batch, batch_nb):
        loss, accuracy = self.forward(batch)
        self.log("loss/val", loss)
        self.log("acc/val", accuracy)

    def test_step(self, batch, batch_nb):
        loss, accuracy = self.forward(batch)
        self.log("acc/test", accuracy)

    def train_dataloader(self):
        transform = ToTensor()
        dataset = MNIST("mnist", train=True, download=True, transform=transform)
        dataloader = DataLoader(
            dataset,
            batch_size=self.my_hparams.batch_size,
            num_workers=self.my_hparams.num_workers,
            shuffle=True,
        )
        return dataloader

    def val_dataloader(self):
        transform = ToTensor()
        dataset = MNIST("mnist", train=False, download=True, transform=transform)
        dataloader = DataLoader(
            dataset,
            batch_size=self.my_hparams.batch_size,
            num_workers=self.my_hparams.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return dataloader

    def test_dataloader(self):
        return self.val_dataloader()
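
    # NOTE: validation and test share the MNIST test split (train=False), and
    # drop_last=True discards the final partial batch during evaluation.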

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.my_hparams.learning_rate,
            weight_decay=self.my_hparams.weight_decay,
            momentum=0.9,
            nesterov=True,
        )
        total_steps = self.my_hparams.max_epochs * len(self.train_dataloader())
        scheduler = {
            "scheduler": WarmupCosineLR(
                optimizer, warmup_epochs=total_steps * 0.3, max_epochs=total_steps
            ),
            "interval": "step",
            "name": "learning_rate",
        }
        return [optimizer], [scheduler]
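
    # NOTE: with "interval": "step" the scheduler is advanced once per optimizer
    # step, so despite their names warmup_epochs/max_epochs are given here in
    # units of steps (30% warmup, then cosine decay over the remaining run).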

    def on_train_epoch_end(self, epoch_output):
        epoch = self.trainer.current_epoch
        state_dict = self.model.state_dict()
        save_dir = "/home/yiming/EXP/mnist_resnet18/Model/Epoch_" + str(epoch + 1)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "subject_model.pth")
        torch.save(state_dict, save_path)
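        # NOTE: this writes the raw backbone state_dict to a hard-coded absolute
        # path after every epoch, independently of the ModelCheckpoint configured
        # in main(); adjust the directory for other machines.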


if __name__ == "__main__":
    main(args)