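"""Training entry point.

Fits LitEfficientNet on MNIST with PyTorch Lightning, evaluates on the test
split, and uploads the resulting checkpoints to S3.
"""
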
import os

import lightning as pl
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from loguru import logger

from src.dataloader import MNISTDataModule
from src.model import LitEfficientNet
from src.utils.aws_s3_services import S3Handler

# Ensure the logs directory exists
os.makedirs("logs", exist_ok=True)

# Configure Loguru for logging
logger.add("logs/training.log", rotation="1 MB", level="INFO")


def main():
    """
    Main training loop for the model with advanced configuration (CPU training).
    """
    # Data Module
    logger.info("Setting up data module...")
    data_module = MNISTDataModule(batch_size=256)

    # Model
    logger.info("Setting up model...")
    model = LitEfficientNet(
        model_name="tf_efficientnet_lite0", num_classes=10, lr=1e-3
    )
    logger.info(model)

    # Callbacks
    logger.info("Setting up callbacks...")
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        dirpath="checkpoints/",
        filename="best_model",
        save_top_k=1,
        mode="max",
        auto_insert_metric_name=False,
        verbose=True,
        save_last=True,
        enable_version_counter=False,
    )
    early_stopping_callback = EarlyStopping(
        monitor="val_acc",
        patience=5,  # stop after 5 validation epochs without val_acc improvement
        mode="max",
        verbose=True,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")  # Log learning rate
    rich_progress = RichProgressBar()
    model_summary = ModelSummary(max_depth=1)  # Show only top-level modules

    # Loggers
    logger.info("Setting up loggers...")
    csv_logger = CSVLogger("logs/", name="mnist_csv")
    tb_logger = TensorBoardLogger("logs/", name="mnist_tb")

    # Trainer configuration (accelerator and devices are auto-detected)
    logger.info("Setting up trainer...")
    trainer = pl.Trainer(
        max_epochs=2,
        callbacks=[
            checkpoint_callback,
            early_stopping_callback,
            lr_monitor,
            rich_progress,
            model_summary,
        ],
        logger=[csv_logger, tb_logger],
        deterministic=True,
        accelerator="auto",
        devices="auto",
    )

    # Train the model
    logger.info("Training the model...")
    trainer.fit(model, datamodule=data_module)
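
    # ModelCheckpoint exposes the path of the best checkpoint it saved; log it
    # so downstream consumers can find it
    logger.info(f"Best checkpoint: {checkpoint_callback.best_model_path}")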

    # Test the model; trainer.test calls the datamodule's setup("test") hook
    # itself, so no explicit setup() call is needed
    logger.info("Testing the model...")
    trainer.test(model, datamodule=data_module)
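    # NOTE: the call above evaluates the model's final in-memory weights;
    # passing ckpt_path="best" would instead restore the best checkpoint
    # saved by ModelCheckpoint before testing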

    # Write a completion flag so downstream steps can detect a finished run
    with open("checkpoints/train_done.flag", "w") as f:
        f.write("Training done.")

    # Upload checkpoints to S3 (assumes AWS credentials are configured for
    # the "deep-bucket-s3" bucket)
    s3_handler = S3Handler(bucket_name="deep-bucket-s3")
    s3_handler.upload_folder(
        "checkpoints",
        "checkpoints_test",
    )


if __name__ == "__main__":
    main()