|
import pytorch_lightning as pl |
|
import torch |
|
import torch.nn as nn |
|
|
|
from src import config |
|
from src import loss as loss_utils |
|
from src import metrics |
|
from src import models |
|
|
|
|
|
class LightningModule(pl.LightningModule):
    """Lightning module that jointly trains a vision encoder and a text encoder.

    Each step encodes the images and the remaining batch entries, builds a
    similarity matrix from the two feature sets, and optimises ``loss_fn`` on
    it. Loss and the two accuracies returned by ``metrics.metrics`` are logged
    once per epoch under a ``training_`` or ``validation_`` prefix.
    """

    def __init__(
        self,
        vision_encoder: models.TinyCLIPVisionEncoder,
        text_encoder: models.TinyCLIPTextEncoder,
        loss_fn: nn.Module,
        hyper_parameters: config.TrainerConfig,
        len_train_dl: int,
    ) -> None:
        """Store the encoders, the loss module and the training configuration.

        Args:
            vision_encoder: Encoder applied to ``batch["images"]``.
            text_encoder: Encoder applied to every other entry of the batch.
            loss_fn: Module called as ``loss_fn(similarity, image_features,
                text_features)``; its parameters are optimised as well.
            hyper_parameters: Trainer configuration; ``learning_rate``,
                ``lr_scheduler`` and ``_model_config`` are read from it.
            len_train_dl: Stored on the instance; not read anywhere in this class.
        """
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder
        self.loss_fn = loss_fn
        self.hyper_parameters = hyper_parameters
        self.len_train_dl = len_train_dl

    def common_step(self, batch: dict[str, torch.Tensor], step_kind: str) -> torch.Tensor:
        """Run one forward pass, log loss and accuracies, and return the loss.

        Args:
            batch: Mapping holding an ``"images"`` entry plus the entries that
                are forwarded (as a dict) to the text encoder.
            step_kind: Prefix used for the logged metric names.

        Returns:
            The value produced by ``self.loss_fn``.
        """
        image_features = self.vision_encoder(batch["images"])
        # Everything that is not the image tensor is handed to the text encoder.
        text_inputs = {key: value for key, value in batch.items() if key != "images"}
        text_features = self.text_encoder(text_inputs)

        similarity_matrix = loss_utils.get_similarity_matrix(image_features, text_features)
        loss = self.loss_fn(similarity_matrix, image_features, text_features)
        img_acc, cap_acc = metrics.metrics(similarity_matrix)

        self.log(f"{step_kind}_loss", loss, on_step=False, on_epoch=True)
        self.log(f"{step_kind}_img_acc", img_acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{step_kind}_cap_acc", cap_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch: dict[str, torch.Tensor], *args: object) -> torch.Tensor:
        """Compute and return the training loss; metrics are logged as ``training_*``."""
        return self.common_step(batch, step_kind="training")

    def validation_step(self, batch: dict[str, torch.Tensor], *args: object) -> None:
        """Evaluate one validation batch; metrics are logged as ``validation_*``."""
        # Use a distinct prefix so validation metrics do not overwrite / mix with
        # the training metrics logged by ``training_step``.
        self.common_step(batch, step_kind="validation")

    def configure_optimizers(self):
        """Build the Adam optimizer and, when enabled, a OneCycle LR schedule.

        The projection heads and the loss module train at ``learning_rate``.
        A base encoder is added at half that rate only when the corresponding
        ``freeze_*_base`` flag of the model config is false.

        Returns:
            The optimizer alone when ``hyper_parameters.lr_scheduler`` is falsy,
            otherwise ``([optimizer], [scheduler_config])`` where the scheduler
            is stepped after every optimizer step.
        """
        base_lr = self.hyper_parameters.learning_rate
        model_config = self.hyper_parameters._model_config

        vision_params = [
            {"params": self.vision_encoder.projection.parameters(), "lr": base_lr},
        ]
        caption_params = [
            {"params": self.text_encoder.projection.parameters(), "lr": base_lr},
        ]
        loss_params = [
            {"params": self.loss_fn.parameters(), "lr": base_lr},
        ]

        if not model_config.freeze_text_base:
            caption_params.append(
                {"params": self.text_encoder.base.parameters(), "lr": base_lr / 2}
            )

        if not model_config.freeze_vision_base:
            vision_params.append(
                {"params": self.vision_encoder.base.parameters(), "lr": base_lr / 2}
            )

        # Group order (vision, caption, loss) is kept stable so optimizer state
        # from earlier checkpoints still lines up with the same groups.
        optimizer = torch.optim.Adam(vision_params + caption_params + loss_params, lr=base_lr)

        if not self.hyper_parameters.lr_scheduler:
            return optimizer

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            # One peak per param group: a scalar here would overwrite the
            # halved learning rate configured for the base encoders.
            max_lr=[group["lr"] for group in optimizer.param_groups],
            total_steps=int(self.trainer.estimated_stepping_batches),
        )
        # ``total_steps`` counts optimizer steps, so the scheduler has to advance
        # every step; Lightning's default interval is once per epoch.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def on_epoch_end(self):
        """Make the vision base trainable once epoch 0 has finished.

        Sets ``requires_grad`` on every parameter of ``vision_encoder.base`` and
        switches that sub-module to train mode.
        """
        # NOTE(review): newer PyTorch Lightning releases appear to have dropped
        # the ``on_epoch_end`` hook in favour of ``on_train_epoch_end`` -- confirm
        # against the pinned version that this method is still invoked.
        # NOTE(review): ``configure_optimizers`` only registers the vision base
        # parameters when ``freeze_vision_base`` is false, so parameters unfrozen
        # here are not updated by the optimizer unless they were registered there.
        if self.current_epoch == 0:
            for p in self.vision_encoder.base.parameters():
                p.requires_grad = True
            self.vision_encoder.base.train()
|
|