import os

import lightning as pl
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    RichProgressBar,
    ModelSummary,
)
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from loguru import logger

from src.dataloader import MNISTDataModule
from src.model import LitEfficientNet
from src.utils.aws_s3_services import S3Handler

os.makedirs("logs", exist_ok=True)
logger.add("logs/training.log", rotation="1 MB", level="INFO")


def main():
    """
    Main training loop for the model with advanced configuration
    (accelerator and devices are selected automatically).
    """
    logger.info("Setting up data module...")
    data_module = MNISTDataModule(batch_size=256)

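    # 10-class EfficientNet-Lite0 classifier with a 1e-3 learning rate
    # (the model_name is a timm-style identifier).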
logger.info("Setting up model...") |
|
model = LitEfficientNet(model_name="tf_efficientnet_lite0", num_classes=10, lr=1e-3) |
|
logger.info(model) |
|
|
|
|
|
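    # Callbacks: keep the best (and last) checkpoint by validation accuracy,
    # stop early when val_acc stops improving, log the learning rate every
    # epoch, show a rich progress bar, and print a depth-1 model summary.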
logger.info("Setting up callbacks...") |
|
checkpoint_callback = ModelCheckpoint( |
|
monitor="val_acc", |
|
dirpath="checkpoints/", |
|
filename="best_model", |
|
save_top_k=1, |
|
mode="max", |
|
auto_insert_metric_name=False, |
|
verbose=True, |
|
save_last=True, |
|
enable_version_counter=False, |
|
) |
|
early_stopping_callback = EarlyStopping( |
|
monitor="val_acc", |
|
patience=5, |
|
mode="max", |
|
verbose=True, |
|
) |
|
lr_monitor = LearningRateMonitor(logging_interval="epoch") |
|
rich_progress = RichProgressBar() |
|
model_summary = ModelSummary( |
|
max_depth=1 |
|
) |
|
|
|
|
|
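    # Log metrics to both a CSV file and TensorBoard under logs/.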
logger.info("Setting up loggers...") |
|
csv_logger = CSVLogger("logs/", name="mnist_csv") |
|
tb_logger = TensorBoardLogger("logs/", name="mnist_tb") |
|
|
|
|
|
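    # Short two-epoch run; deterministic mode for reproducibility, with the
    # accelerator and device count chosen automatically.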
logger.info("Setting up trainer...") |
|
trainer = pl.Trainer( |
|
max_epochs=2, |
|
callbacks=[ |
|
checkpoint_callback, |
|
early_stopping_callback, |
|
lr_monitor, |
|
rich_progress, |
|
model_summary, |
|
], |
|
logger=[csv_logger, tb_logger], |
|
deterministic=True, |
|
accelerator="auto", |
|
devices="auto", |
|
) |
|
|
|
|
|
logger.info("Training the model...") |
|
trainer.fit(model, datamodule=data_module) |
|
|
|
|
|
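    # Note: trainer.test(model, ...) evaluates the model's current (final-epoch)
    # weights rather than the saved "best_model" checkpoint.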
logger.info("Testing the model...") |
|
data_module.setup(stage="test") |
|
trainer.test(model, datamodule=data_module) |
|
|
|
|
|
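    # Write a simple marker file indicating that training completed.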
with open("checkpoints/train_done.flag", "w") as f: |
|
f.write("Training done.") |
|
|
|
|
|
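    # Upload the local checkpoints/ folder to the "deep-bucket-s3" bucket using
    # the project's S3Handler utility.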
    s3_handler = S3Handler(bucket_name="deep-bucket-s3")
    s3_handler.upload_folder(
        "checkpoints",
        "checkpoints_test",
    )


if __name__ == "__main__":
    main()