"""Train and test a LitEfficientNet model on MNIST with PyTorch Lightning.

Trains with auto-detected accelerator/devices, checkpoints the best model by
validation accuracy, tests it, writes a completion flag, and uploads the
checkpoint folder to S3.
"""

import os

import lightning as pl
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from loguru import logger

from src.dataloader import MNISTDataModule
from src.model import LitEfficientNet
from src.utils.aws_s3_services import S3Handler

# Ensure the logs directory exists before attaching the file sink below.
os.makedirs("logs", exist_ok=True)

# Configure Loguru for logging (rotate the file once it reaches 1 MB).
logger.add("logs/training.log", rotation="1 MB", level="INFO")


def main():
    """Run the full train -> test -> upload pipeline.

    Side effects: writes model checkpoints to ``checkpoints/``, logs to
    ``logs/`` (Loguru file sink, CSV, TensorBoard), writes
    ``checkpoints/train_done.flag``, and uploads the ``checkpoints`` folder
    to the ``deep-bucket-s3`` S3 bucket.
    """
    # Seed all RNGs so that `deterministic=True` on the Trainer below
    # actually produces reproducible runs (weight init, shuffling, dropout).
    pl.seed_everything(42, workers=True)

    # Data Module
    logger.info("Setting up data module...")
    data_module = MNISTDataModule(batch_size=256)

    # Model
    logger.info("Setting up model...")
    model = LitEfficientNet(
        model_name="tf_efficientnet_lite0", num_classes=10, lr=1e-3
    )
    logger.info(model)

    # Callbacks
    logger.info("Setting up callbacks...")
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        dirpath="checkpoints/",
        filename="best_model",
        save_top_k=1,
        mode="max",
        auto_insert_metric_name=False,
        verbose=True,
        save_last=True,
        enable_version_counter=False,
    )
    early_stopping_callback = EarlyStopping(
        monitor="val_acc",
        patience=5,  # Extended patience for advanced models
        mode="max",
        verbose=True,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")  # Log learning rate
    rich_progress = RichProgressBar()
    model_summary = ModelSummary(
        max_depth=1
    )  # Show only the first level of model layers

    # Loggers
    logger.info("Setting up loggers...")
    csv_logger = CSVLogger("logs/", name="mnist_csv")
    tb_logger = TensorBoardLogger("logs/", name="mnist_tb")

    # Trainer Configuration
    logger.info("Setting up trainer...")
    trainer = pl.Trainer(
        max_epochs=2,
        callbacks=[
            checkpoint_callback,
            early_stopping_callback,
            lr_monitor,
            rich_progress,
            model_summary,
        ],
        logger=[csv_logger, tb_logger],
        deterministic=True,
        accelerator="auto",
        devices="auto",
    )

    # Train the model
    logger.info("Training the model...")
    trainer.fit(model, datamodule=data_module)

    # Test the model. trainer.test() calls the datamodule's setup("test")
    # hook itself, so no manual data_module.setup(stage="test") is needed.
    logger.info("Testing the model...")
    trainer.test(model, datamodule=data_module)

    # Write a completion flag so downstream jobs can detect a finished run.
    # Guard against the edge case where training ended before ModelCheckpoint
    # ever created the directory.
    os.makedirs("checkpoints", exist_ok=True)
    with open("checkpoints/train_done.flag", "w") as f:
        f.write("Training done.")

    # Upload checkpoints to S3
    s3_handler = S3Handler(bucket_name="deep-bucket-s3")
    s3_handler.upload_folder(
        "checkpoints",
        "checkpoints_test",
    )


if __name__ == "__main__":
    main()