import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import retrieval_average_precision
import pytorch_lightning as pl
from src.dinov2.models.vision_transformer import vit_base
from functools import partial
# from src.clip import clip
from src.options import opts


def freeze_model(m):
    # Freeze every parameter of the given module.
    m.requires_grad_(False)


def freeze_all_but_bn(m):
    # Freeze all weights and biases except those of LayerNorm modules,
    # which are kept trainable.
    if not isinstance(m, torch.nn.LayerNorm):
        if hasattr(m, 'weight') and m.weight is not None:
            m.weight.requires_grad_(False)
        if hasattr(m, 'bias') and m.bias is not None:
            m.bias.requires_grad_(False)
    else:
        print("LayerNorm")


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.opts = opts
        self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)
        print("self.dino", self.dino)

        # Prompt engineering: learnable prompt tokens, one set for sketches
        # and one set for images, prepended to the backbone input.
        self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
        self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))

        # Cosine distance (1 - cosine similarity) used by the triplet loss.
        self.distance_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y)
        self.loss_fn_triplet = nn.TripletMarginWithDistanceLoss(
            distance_function=self.distance_fn, margin=0.2)
        self.emb_cos_loss = nn.CosineEmbeddingLoss(margin=0.2)
        self.loss_kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
        self.best_metric = -1e3

        # normalization layer for the representations z1 and z2
        # self.bn = nn.BatchNorm1d(self.opts.prompt_dim, affine=False)

    def configure_optimizers(self):
        if self.opts.model_type == 'one_encoder':
            model_params = list(self.dino.parameters())
        else:
            # Two-encoder variant: assumes a separate sketch encoder
            # (self.clip_sk) is defined elsewhere, e.g. via the commented-out
            # CLIP import above; it is not instantiated in this class.
            model_params = list(self.dino.parameters()) + list(self.clip_sk.parameters())

        # Backbone parameters and prompt tokens use separate learning rates.
        optimizer = torch.optim.Adam([
            {'params': model_params, 'lr': self.opts.clip_LN_lr},
            {'params': [self.sk_prompt] + [self.img_prompt], 'lr': self.opts.prompt_lr}])
        return optimizer

    def forward(self, data, dtype='image'):
        # Route the batch through the shared DINOv2 backbone with the
        # modality-specific prompt (image vs. sketch) expanded to batch size.
        if dtype == 'image':
            feat = self.dino(data, prompt=self.img_prompt.expand(data.shape[0], -1, -1))
        else:
            feat = self.dino(data, prompt=self.sk_prompt.expand(data.shape[0], -1, -1))
        return feat
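

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how
# the prompted forward pass and the triplet objective fit together, assuming
# the repo's `opts` supplies n_prompts / prompt_dim / model_type and that the
# prompt-enabled DINOv2 ViT accepts standard 224x224 RGB crops (a multiple of
# the 14-pixel patch size). The batch tensors below are random placeholders,
# not real data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Model()

    # Hypothetical batch: anchor sketches, matching photos, and non-matching
    # photos, each of shape (B, 3, 224, 224).
    sk = torch.randn(4, 3, 224, 224)
    img_pos = torch.randn(4, 3, 224, 224)
    img_neg = torch.randn(4, 3, 224, 224)

    # Sketches and photos share the same backbone but use different learned
    # prompt tokens (sk_prompt vs. img_prompt).
    sk_feat = model(sk, dtype='sketch')
    pos_feat = model(img_pos, dtype='image')
    neg_feat = model(img_neg, dtype='image')

    # Triplet objective under cosine distance (margin 0.2): pull matching
    # sketch/photo pairs together, push non-matching pairs apart.
    loss = model.loss_fn_triplet(sk_feat, pos_feat, neg_feat)
    print("triplet loss:", loss.item())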