soutrik
orphan branch
c3d82b0
raw
history blame
3 kB
import lightning as pl
import torch.nn as nn
import torch
from timm import create_model
from torchmetrics.classification import Accuracy
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from loguru import logger
# Register a file sink for this module's log records: INFO and above go to
# logs/model.log, and the file is rotated once it reaches 1 MB.
logger.add(
    "logs/model.log",
    rotation="1 MB",
    level="INFO",
)
class LitEfficientNet(pl.LightningModule):
    """LightningModule that wraps a TIMM backbone for multiclass classification.

    The backbone is built with ``timm.create_model``. Separate TorchMetrics
    ``Accuracy`` instances are kept for the train, validation and test loops so
    their accumulated state never mixes.
    """

    def __init__(
        self,
        model_name="tf_efficientnet_lite0",
        num_classes=10,
        lr=1e-3,
        custom_loss=None,
        *,
        pretrained=True,
        in_chans=1,
    ):
        """
        Initializes a CNN model from TIMM and integrates TorchMetrics.

        Args:
            model_name (str): TIMM model name (e.g., "tf_efficientnet_lite0").
            num_classes (int): Number of output classes (e.g., 0–9 for MNIST).
            lr (float): Learning rate for the optimizer.
            custom_loss (callable, optional): Custom loss function. Defaults to
                CrossEntropyLoss when ``None``.
            pretrained (bool): Forwarded to ``timm.create_model``; whether to
                load pretrained weights. Defaults to True.
            in_chans (int): Forwarded to ``timm.create_model``; number of input
                channels. Defaults to 1 (grayscale input).
        """
        super().__init__()
        self.lr = lr
        self.model = create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=in_chans,
        )

        # Compare against None explicitly: a valid loss object may be falsy
        # (e.g. a module defining __len__), and must not be silently replaced.
        self.loss_fn = custom_loss if custom_loss is not None else nn.CrossEntropyLoss()

        # One metric instance per loop so accumulated state stays separate.
        self.train_acc = Accuracy(num_classes=num_classes, task="multiclass")
        self.val_acc = Accuracy(num_classes=num_classes, task="multiclass")
        self.test_acc = Accuracy(num_classes=num_classes, task="multiclass")

        logger.info(f"Model initialized with TIMM backbone: {model_name}")
        logger.info(f"Number of output classes: {num_classes}")

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Model predictions.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Run one training batch: compute the loss, update accuracy, log both.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_idx (int): Index of the batch (unused).

        Returns:
            The loss computed by ``self.loss_fn`` for this batch.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)

        # Call the metric (forward) instead of update(): forward produces the
        # per-batch value that step-level logging of the metric object reports,
        # while update() alone only accumulates state.
        self.train_acc(y_hat, y)

        self.log("train_loss", loss, prog_bar=True, logger=True)
        self.log("train_acc", self.train_acc, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Run one validation batch: compute the loss, accumulate accuracy, log both.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_idx (int): Index of the batch (unused).
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.val_acc.update(y_hat, y)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", self.val_acc, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        """
        Run one test batch: accumulate accuracy and log the metric.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_idx (int): Index of the batch (unused).
        """
        x, y = batch
        y_hat = self(x)
        self.test_acc.update(y_hat, y)
        self.log("test_acc", self.test_acc, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """
        Build the optimizer and learning-rate scheduler.

        Returns:
            A pair of lists: ``[optimizer], [scheduler]`` holding an Adam
            optimizer over all parameters and a StepLR scheduler
            (step_size=1, gamma=0.9).
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.9)
        logger.info(f"Optimizer: Adam, Learning Rate: {self.lr}")
        logger.info("Scheduler: StepLR with step_size=1 and gamma=0.9")
        return [optimizer], [scheduler]