Spaces:
Runtime error
Runtime error
import lightning as pl | |
import torch.nn as nn | |
import torch | |
from timm import create_model | |
from torchmetrics.classification import Accuracy | |
from torch.optim.lr_scheduler import StepLR | |
import torch.optim as optim | |
from loguru import logger | |
logger.add("logs/model.log", rotation="1 MB", level="INFO") | |
class LitEfficientNet(pl.LightningModule): | |
def __init__( | |
self, | |
model_name="tf_efficientnet_lite0", | |
num_classes=10, | |
lr=1e-3, | |
custom_loss=None, | |
): | |
""" | |
Initializes a CNN model from TIMM and integrates TorchMetrics. | |
Args: | |
model_name (str): TIMM model name (e.g., "tf_efficientnet_lite0"). | |
num_classes (int): Number of output classes (e.g., 0β9 for MNIST). | |
lr (float): Learning rate for the optimizer. | |
custom_loss (callable, optional): Custom loss function. Defaults to CrossEntropyLoss. | |
""" | |
super().__init__() | |
self.lr = lr | |
self.model = create_model( | |
model_name, | |
pretrained=True, | |
num_classes=num_classes, | |
in_chans=1, # Set to 1 channel for grayscale input | |
) | |
self.loss_fn = custom_loss or nn.CrossEntropyLoss() | |
self.train_acc = Accuracy(num_classes=num_classes, task="multiclass") | |
self.val_acc = Accuracy(num_classes=num_classes, task="multiclass") | |
self.test_acc = Accuracy(num_classes=num_classes, task="multiclass") | |
logger.info(f"Model initialized with TIMM backbone: {model_name}") | |
logger.info(f"Number of output classes: {num_classes}") | |
def forward(self, x): | |
""" | |
Forward pass of the model. | |
Args: | |
x (torch.Tensor): Input tensor. | |
Returns: | |
torch.Tensor: Model predictions. | |
""" | |
return self.model(x) | |
def training_step(self, batch, batch_idx): | |
x, y = batch | |
y_hat = self(x) | |
loss = self.loss_fn(y_hat, y) | |
self.train_acc.update(y_hat, y) | |
self.log("train_loss", loss, prog_bar=True, logger=True) | |
self.log("train_acc", self.train_acc, prog_bar=True, logger=True) | |
return loss | |
def validation_step(self, batch, batch_idx): | |
x, y = batch | |
y_hat = self(x) | |
loss = self.loss_fn(y_hat, y) | |
self.val_acc.update(y_hat, y) | |
self.log("val_loss", loss, prog_bar=True, logger=True) | |
self.log("val_acc", self.val_acc, prog_bar=True, logger=True) | |
def test_step(self, batch, batch_idx): | |
x, y = batch | |
y_hat = self(x) | |
self.test_acc.update(y_hat, y) | |
self.log("test_acc", self.test_acc, prog_bar=True, logger=True) | |
def configure_optimizers(self): | |
optimizer = optim.Adam(self.parameters(), lr=self.lr) | |
scheduler = StepLR(optimizer, step_size=1, gamma=0.9) | |
logger.info(f"Optimizer: Adam, Learning Rate: {self.lr}") | |
logger.info("Scheduler: StepLR with step_size=1 and gamma=0.9") | |
return [optimizer], [scheduler] | |