import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchmetrics.functional import retrieval_average_precision | |
import pytorch_lightning as pl | |
from src.dinov2.models.vision_transformer import vit_base | |
from functools import partial | |
# from src.clip import clip | |
from src.options import opts | |
def freeze_model(m):
    """Switch off gradient tracking for ``m`` in place and return None.

    ``m`` can be anything that exposes ``requires_grad_``. For an
    ``nn.Module`` this freezes all of its parameters recursively; for a
    tensor it freezes the tensor itself.
    """
    m.requires_grad_(requires_grad=False)
def freeze_all_but_bn(m):
    """Freeze the ``weight``/``bias`` of module ``m`` unless it is a LayerNorm.

    This function is intended as a per-module visitor, e.g. for
    ``model.apply(freeze_all_but_bn)``. Despite the name, the exempted type
    is ``torch.nn.LayerNorm``, not BatchNorm.

    - LayerNorm modules are left trainable, and the string "LayerNorm" is
      printed once for each one visited.
    - For every other module, only attributes literally named ``weight`` and
      ``bias`` are frozen, and only when they exist and are not None.
    - Other parameters a module may own under different names are left
      untouched.
    """
    if isinstance(m, torch.nn.LayerNorm):
        print("LayerNorm")
        return
    for attr_name in ('weight', 'bias'):
        tensor = getattr(m, attr_name, None)
        if tensor is not None:
            tensor.requires_grad_(False)
class Model(pl.LightningModule): | |
def __init__(self): | |
super().__init__() | |
self.opts = opts | |
self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0) | |
print("self.dino", self.dino) | |
# Prompt Engineering | |
self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim)) | |
self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim)) | |
self.distance_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y) | |
self.loss_fn_triplet = nn.TripletMarginWithDistanceLoss( | |
distance_function=self.distance_fn, margin=0.2) | |
self.emb_cos_loss = nn.CosineEmbeddingLoss(margin=0.2) | |
self.loss_kl = nn.KLDivLoss(reduction="batchmean", log_target=True) | |
self.best_metric = -1e3 | |
# normalization layer for the representations z1 and z2 | |
# = nn.BatchNorm1d(self.opts.prompt_dim, affine=False) | |
def configure_optimizers(self): | |
if self.opts.model_type == 'one_encoder': | |
model_params = list(self.dino.parameters()) | |
else: | |
model_params = list(self.dino.parameters()) + list(self.clip_sk.parameters()) | |
optimizer = torch.optim.Adam([ | |
{'params': model_params, 'lr': self.opts.clip_LN_lr}, | |
{'params': [self.sk_prompt] + [self.img_prompt], 'lr': self.opts.prompt_lr}]) | |
return optimizer | |
def forward(self, data, dtype='image'): | |
if dtype == 'image': | |
feat = self.dino(data, prompt=self.img_prompt.expand(data.shape[0], -1, -1)) | |
else: | |
feat = self.dino(data, prompt=self.sk_prompt.expand(data.shape[0], -1, -1)) | |
return feat |