|
import yaml |
|
from pathlib import Path |
|
import click |
|
import torch |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
|
|
from models.mobilevit import MobileVIT |
|
from data.data_preprocessing import FluorescentNeuronalDataModule |
|
|
|
CONFIG_FILE = "config/fluorescent_mobilevit_hps.yaml" |
|
DATA_DIR = "data/raw/" |
|
LOGS_DIR = "reports/logs/FluorescentMobileVIT" |
|
MODEL_DIR = "models/FluorescentMobileVIT" |
|
|
|
|
|
if torch.backends.mps.is_available(): |
|
DEVICE = torch.device("mps:0") |
|
ACCELERATOR = "mps" |
|
elif torch.cuda.is_available(): |
|
DEVICE = torch.device("cuda") |
|
ACCELERATOR = "gpu" |
|
else: |
|
DEVICE = torch.device("cpu") |
|
ACCELERATOR = "cpu" |
|
|
|
|
|
@click.command() |
|
@click.option( |
|
"--data_dir", |
|
type=click.Path(exists=True, file_okay=True, path_type=Path), |
|
default=DATA_DIR, |
|
) |
|
@click.option( |
|
"--config_file", |
|
type=click.Path(exists=True, file_okay=True, path_type=Path), |
|
default=CONFIG_FILE, |
|
) |
|
def train_model(data_dir, config_file): |
|
|
|
with open(config_file, "r") as file: |
|
best_params = yaml.safe_load(file) |
|
|
|
model = MobileVIT( |
|
learning_rate=best_params["learning_rate"], |
|
weight_decay=best_params["weight_decay"], |
|
) |
|
|
|
model_checkpoint_cb = ModelCheckpoint( |
|
save_top_k=1, dirpath=MODEL_DIR, monitor="val_loss" |
|
) |
|
logger = TensorBoardLogger(save_dir=LOGS_DIR) |
|
|
|
|
|
trainer = pl.Trainer( |
|
logger=logger, |
|
devices=1, |
|
accelerator=ACCELERATOR, |
|
precision=16, |
|
max_epochs=100, |
|
log_every_n_steps=20, |
|
callbacks=[model_checkpoint_cb], |
|
) |
|
data_module = FluorescentNeuronalDataModule( |
|
data_dir=data_dir, batch_size=best_params["batch_size"] |
|
) |
|
trainer.fit(model=model, datamodule=data_module) |
|
trainer.test(model=model, datamodule=data_module) |
|
click.echo("\n\n==========The Training has Finished!==========") |
|
|