DmitriiKhizbullin's picture
Reorganized files
8afb176
import datetime
from argparse import ArgumentParser
import torch
from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelSummary
from src.trainer import ViTLightningModule
def main():
""" Neural network trainer entry point. """
parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
parser.add_argument('--tag', action='store', type=str,
help='Extra suffix to put on the artefact dir name')
parser.add_argument('--debug', action='store_true',
help="Dummy training cycle for testing purposes")
parser.add_argument('--convert-checkpoint', action='store', type=str,
help='Convert a checkpoint from training to pickle-independent '
'predictor-compatible directory')
args = parser.parse_args()
torch.set_float32_matmul_precision('high') # for V100/A100
if args.convert_checkpoint is not None:
print("Converting checkpoint", args.convert_checkpoint)
checkpoint = torch.load(args.convert_checkpoint, map_location="cpu")
print(list(checkpoint.keys()))
model = ViTLightningModule.load_from_checkpoint(
args.convert_checkpoint,
map_location="cpu",
hparams_file="tmp_ckpt_deleteme.yaml")
model.save_checkpoint_dk("tmp_checkp_path_deleteme")
print("Saved checkpoint. Done.")
else:
print("Start training")
fast_dev_run = True if args.debug == True else False
model = ViTLightningModule(fast_dev_run)
datetime_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
art_dir_name = (f"{datetime_str}" +
(f"_{args.tag}" if args.tag is not None else ""))
logger = TensorBoardLogger(save_dir=".", name="lightning_logs", version=art_dir_name)
trainer = Trainer(
logger=logger,
benchmark=True,
devices="auto",
accelerator="auto",
max_epochs=-1,
callbacks=[
ModelSummary(max_depth=-1),
],
fast_dev_run=fast_dev_run,
log_every_n_steps=10,
)
trainer.fit(
model,
train_dataloaders=model._train_dataloader,
val_dataloaders=model._val_dataloader,
)
print("Training done")
if __name__ == "__main__":
main()