# SalazarPevelll — commit 88ebb5a: "add training dynamic" (GitHub header residue, kept as a comment)
"""
https://github.com/marrrcin/pytorch-resnet-mnist/blob/master/pytorch-resnet-mnist.ipynb
https://github.com/huyvnphan/PyTorch_CIFAR10/tree/master/cifar10_models
"""
import os, sys
from argparse import ArgumentParser
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.accelerators import GPUAccelerator
sys.path.append('/home/yiming/ContrastDebugger/Model-mnist')
from data import MNISTData
# from module import MNISTModule
import pytorch_lightning as pl
import torch
# from torchmetrics import Accuracy
from pytorch_lightning.metrics import Accuracy
from cifar10_models.densenet import densenet121, densenet161, densenet169
from cifar10_models.googlenet import googlenet
from cifar10_models.inception import inception_v3
from cifar10_models.mobilenetv2 import mobilenet_v2
from cifar10_models.resnet import resnet18, resnet34, resnet50
from cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from cifar10_models.mlp import mlp3
from cifar10_models.convnet import convnet
from schduler import WarmupCosineLR
from torch import nn
# from pytorch_lightning.core.decorators import auto_move_data
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import json
import os
# ----- command-line configuration --------------------------------------------
# NOTE: parse_args(args=[]) deliberately ignores sys.argv, so every run uses
# the defaults below (handy when the file is executed from a notebook).
parser = ArgumentParser()
_ARG_SPECS = [
    # PROGRAM level args
    ("--data_dir", dict(type=str, default="data")),
    ("--test_phase", dict(type=int, default=0, choices=[0, 1])),
    ("--dev", dict(type=int, default=0, choices=[0, 1])),
    # TRAINER args
    ("--classifier", dict(type=str, default="resnet18")),
    ("--precision", dict(type=int, default=32, choices=[16, 32])),
    ("--batch_size", dict(type=int, default=128)),
    ("--max_epochs", dict(type=int, default=20)),
    ("--num_workers", dict(type=int, default=2)),
    ("--gpu_id", dict(type=str, default="0")),
    ("--learning_rate", dict(type=float, default=5e-3)),
    ("--weight_decay", dict(type=float, default=1e-2)),
    ("--filepath", dict(type=str, default="Model")),
    ("--period", dict(type=int, default=1)),
    ("--save_top_k", dict(type=int, default=-1)),
]
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args(args=[])
def main(args):
    """Build the checkpoint callback, Trainer and model for one run.

    NOTE: the data-export / fit / test calls below are currently commented
    out, so in the file's present state this function only constructs the
    objects — it does not start training.
    """
    seed_everything(0)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    # Checkpoint every `period` epochs, named by epoch number, ranked on
    # validation accuracy (save_top_k=-1 keeps them all).
    ckpt_callback = ModelCheckpoint(
        filepath=os.path.join(args.filepath, args.classifier, "{epoch:03d}"),
        monitor="acc/val",
        mode="max",
        period=args.period,
        save_top_k=args.save_top_k,
        save_weights_only=True,
    )

    runner = Trainer(
        fast_dev_run=bool(args.dev),
        gpus=args.gpu_id,
        deterministic=True,
        weights_summary=None,
        log_every_n_steps=1,
        max_epochs=args.max_epochs,
        checkpoint_callback=ckpt_callback,
        precision=args.precision,
    )

    model = MNISTModule(args)
    # Disabled pipeline (export data, then fit / test):
    # data = MNISTData(args)
    # trainloader = data.train_dataloader()
    # data.save_train_data(trainloader, args.filepath)
    # testloader = data.test_dataloader()
    # data.save_test_data(testloader, args.filepath)
    # if bool(args.test_phase):
    #     runner.test(model, data.test_dataloader())
    # else:
    #     runner.fit(model, data)
    #     runner.test()
# Registry mapping the --classifier flag to a freshly built (untrained) model.
# NOTE(review): every architecture is instantiated eagerly at import time even
# though only one is ever looked up — lazy constructors would cut startup cost,
# but the consumer indexes this dict for instances, so keep it as-is.
all_classifiers = dict(
    vgg11_bn=vgg11_bn(),
    vgg13_bn=vgg13_bn(),
    vgg16_bn=vgg16_bn(),
    vgg19_bn=vgg19_bn(),
    resnet18=resnet18(),
    resnet34=resnet34(),
    resnet50=resnet50(),
    densenet121=densenet121(),
    densenet161=densenet161(),
    densenet169=densenet169(),
    mobilenet_v2=mobilenet_v2(),
    googlenet=googlenet(),
    inception_v3=inception_v3(),
    mlp=mlp3(),
    convnet=convnet(),
)
class MNISTModule(pl.LightningModule):
    """LightningModule wrapping one of the `all_classifiers` backbones for MNIST.

    The backbone is selected by ``my_hparams.classifier``; for the conv
    architectures the stem conv is replaced so the network accepts
    single-channel MNIST images. Training uses SGD with a per-step
    warmup-cosine learning-rate schedule, and the raw backbone weights are
    snapshotted to disk at the end of every training epoch.
    """

    def __init__(self, my_hparams):
        """Build the module from the parsed argparse namespace ``my_hparams``."""
        super().__init__()
        self.my_hparams = my_hparams
        self.criterion = torch.nn.CrossEntropyLoss()
        self.accuracy = Accuracy()
        # Look up the already-instantiated backbone in the module-level registry.
        self.model = all_classifiers[self.my_hparams.classifier]
        if self.my_hparams.classifier not in ["mlp", "convnet"]:
            # Swap the RGB stem for a 1-channel conv so MNIST input fits.
            # NOTE(review): only the ResNet family names its stem `conv1`; for
            # VGG/DenseNet/GoogLeNet/... this assignment merely attaches an
            # unused attribute — confirm before using a non-ResNet classifier.
            self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    def forward(self, batch):
        """Run one (images, labels) batch; return (loss, accuracy-in-percent).

        Note: unlike the usual Lightning convention, forward takes a whole
        batch rather than a tensor, because the step hooks below feed it
        their batch directly.
        """
        images, labels = batch
        predictions = self.model(images)
        loss = self.criterion(predictions, labels)
        accuracy = self.accuracy(predictions, labels)
        return loss, accuracy * 100

    def training_step(self, batch, batch_nb):
        """Log train loss/accuracy and return the loss for backprop."""
        loss, accuracy = self.forward(batch)
        self.log("loss/train", loss)
        self.log("acc/train", accuracy)
        return loss

    def validation_step(self, batch, batch_nb):
        """Log validation loss/accuracy ("acc/val" drives ModelCheckpoint in main)."""
        loss, accuracy = self.forward(batch)
        self.log("loss/val", loss)
        self.log("acc/val", accuracy)

    def test_step(self, batch, batch_nb):
        """Log test accuracy only (loss is computed but not logged)."""
        loss, accuracy = self.forward(batch)
        self.log("acc/test", accuracy)

    def train_dataloader(self):
        """Shuffled loader over the MNIST training split (downloads if absent)."""
        transform = ToTensor()
        dataset = MNIST("mnist", train=True, download=True, transform=transform)
        dataloader = DataLoader(
            dataset,
            batch_size=self.my_hparams.batch_size,
            num_workers=self.my_hparams.num_workers,
            shuffle=True,
        )
        return dataloader

    def val_dataloader(self):
        """Loader over the MNIST test split, used here as the validation set.

        drop_last=True discards any trailing partial batch, so logged
        validation metrics cover slightly less than the full split.
        """
        transform = ToTensor()
        dataset = MNIST("mnist", train=False, download=True, transform=transform)
        dataloader = DataLoader(
            dataset,
            batch_size=self.my_hparams.batch_size,
            num_workers=self.my_hparams.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return dataloader

    def test_dataloader(self):
        """Testing reuses the validation loader (MNIST test split)."""
        return self.val_dataloader()

    def configure_optimizers(self):
        """SGD (nesterov momentum 0.9) + WarmupCosineLR stepped per batch."""
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.my_hparams.learning_rate,
            weight_decay=self.my_hparams.weight_decay,
            momentum=0.9,
            nesterov=True,
        )
        # The schedule runs in optimizer *steps*, not epochs: warm up over the
        # first 30% of all steps, then cosine-decay to the end of training.
        total_steps = self.my_hparams.max_epochs * len(self.train_dataloader())
        scheduler = {
            "scheduler": WarmupCosineLR(
                optimizer, warmup_epochs=total_steps * 0.3, max_epochs=total_steps
            ),
            "interval": "step",
            "name": "learning_rate",
        }
        return [optimizer], [scheduler]

    def on_train_epoch_end(self, epoch_output):
        """Save the raw backbone state_dict after each training epoch.

        Fix: dropped an unused local (`epoch = self.trainer.current_epoch`)
        that shadowed the `self.current_epoch` actually used below.
        """
        state_dict = self.model.state_dict()
        # Epoch folders are 1-based (Epoch_1 after the first epoch).
        # NOTE(review): hard-coded absolute path — inconsistent with the
        # args.filepath-based location used by ModelCheckpoint in main();
        # confirm which destination downstream tooling expects.
        save_dir = "/home/yiming/EXP/mnist_resnet18/Model/Epoch_" + str(self.current_epoch + 1)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "subject_model.pth")
        torch.save(state_dict, save_path)
# Guard the entry point so importing this module does not kick off a run.
if __name__ == "__main__":
    main(args)