# train.py — MNIST EfficientNet training pipeline (2,971 bytes, rev c3d82b0)
import lightning as pl
from lightning.pytorch.callbacks import (
ModelCheckpoint,
EarlyStopping,
LearningRateMonitor,
RichProgressBar,
)
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelSummary
from src.dataloader import MNISTDataModule
from src.model import LitEfficientNet
from loguru import logger
import os
from src.utils.aws_s3_services import S3Handler
# Create the log directory up front so the Loguru file sink below can open it.
os.makedirs("logs", exist_ok=True)
# Attach a rotating file sink: roll over once the file reaches 1 MB,
# recording INFO-level messages and above.
logger.add("logs/training.log", rotation="1 MB", level="INFO")
def main():
    """
    Run the full train → test → export pipeline for the MNIST model.

    Workflow:
      1. Seed RNGs and build the data module and EfficientNet-lite model.
      2. Train with checkpointing, early stopping, LR monitoring, and rich
         progress/summary callbacks, logging to both CSV and TensorBoard.
      3. Test the model, write a ``checkpoints/train_done.flag`` marker,
         and upload the checkpoints folder to S3.
    """
    # Seed all RNGs so that ``deterministic=True`` below actually produces
    # reproducible runs — the flag alone does not seed anything.
    pl.seed_everything(42, workers=True)

    # Data Module
    logger.info("Setting up data module...")
    data_module = MNISTDataModule(batch_size=256)

    # Model
    logger.info("Setting up model...")
    model = LitEfficientNet(
        model_name="tf_efficientnet_lite0", num_classes=10, lr=1e-3
    )
    logger.info(model)

    # Callbacks
    logger.info("Setting up callbacks...")
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        dirpath="checkpoints/",
        filename="best_model",
        save_top_k=1,
        mode="max",  # val_acc: higher is better
        auto_insert_metric_name=False,
        verbose=True,
        save_last=True,
        enable_version_counter=False,
    )
    early_stopping_callback = EarlyStopping(
        monitor="val_acc",
        patience=5,  # extended patience for advanced models
        mode="max",
        verbose=True,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")  # log LR once per epoch
    rich_progress = RichProgressBar()
    model_summary = ModelSummary(max_depth=1)  # show only top-level model layers

    # Loggers
    logger.info("Setting up loggers...")
    csv_logger = CSVLogger("logs/", name="mnist_csv")
    tb_logger = TensorBoardLogger("logs/", name="mnist_tb")

    # Trainer configuration (accelerator/devices auto-detected)
    logger.info("Setting up trainer...")
    trainer = pl.Trainer(
        max_epochs=2,
        callbacks=[
            checkpoint_callback,
            early_stopping_callback,
            lr_monitor,
            rich_progress,
            model_summary,
        ],
        logger=[csv_logger, tb_logger],
        deterministic=True,
        accelerator="auto",
        devices="auto",
    )

    # Train the model
    logger.info("Training the model...")
    trainer.fit(model, datamodule=data_module)

    # Test the model. ``trainer.test(datamodule=...)`` invokes
    # ``setup(stage="test")`` on the datamodule itself, so the previous
    # explicit ``data_module.setup(...)`` call (which ran setup twice)
    # is not needed.
    logger.info("Testing the model...")
    trainer.test(model, datamodule=data_module)

    # Write a completion marker. Ensure the directory exists even if
    # training was interrupted before the checkpoint callback created it.
    os.makedirs("checkpoints", exist_ok=True)
    with open("checkpoints/train_done.flag", "w") as f:
        f.write("Training done.")

    # Upload checkpoints to S3
    s3_handler = S3Handler(bucket_name="deep-bucket-s3")
    s3_handler.upload_folder(
        "checkpoints",
        "checkpoints_test",
    )
# Script entry point: run the full train/test/upload pipeline.
if __name__ == "__main__":
    main()
|