Spaces:
Sleeping
Sleeping
File size: 2,400 Bytes
265ae36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import retrieval_average_precision
import pytorch_lightning as pl
from src.dinov2.models.vision_transformer import vit_base
from functools import partial
# from src.clip import clip
from src.options import opts
def freeze_model(m):
    """Disable gradient tracking for every parameter of module *m*.

    Args:
        m: A ``torch.nn.Module`` (or any object exposing ``requires_grad_``).

    The change is made in place by delegating to ``m.requires_grad_(False)``;
    nothing is returned.
    """
    m.requires_grad_(False)
def freeze_all_but_bn(m):
    """Freeze the ``weight``/``bias`` of *m* unless *m* is a LayerNorm.

    Despite the "bn" in the name, the exempt layer type checked here is
    ``torch.nn.LayerNorm``. The function only looks at the attributes of
    *m* itself (it does not recurse into children), so it is suited to
    being passed to ``Module.apply``.

    Args:
        m: The module to inspect. Modules without a ``weight`` or ``bias``
            attribute, or where that attribute is ``None``, are left
            untouched.
    """
    if isinstance(m, torch.nn.LayerNorm):
        # LayerNorm parameters stay trainable; just report that one was seen.
        print("LayerNorm")
        return

    weight = getattr(m, 'weight', None)
    if weight is not None:
        weight.requires_grad_(False)

    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.requires_grad_(False)
class Model(pl.LightningModule):
    """Prompt-conditioned ViT encoder shared by the image and sketch inputs.

    A single ``vit_base`` backbone (``self.dino``) encodes both kinds of
    input; the kind is distinguished only by which learnable prompt
    (``img_prompt`` or ``sk_prompt``) is handed to the backbone.
    """

    def __init__(self):
        """Build the backbone, the two learnable prompts and the loss objects."""
        super().__init__()
        self.opts = opts
        self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)
        print("self.dino", self.dino)

        # Prompt Engineering: one learnable (n_prompts, prompt_dim) matrix
        # per input kind, initialised from a standard normal.
        self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
        self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))

        # Cosine distance (1 - cosine similarity) used by the triplet loss.
        self.distance_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y)
        self.loss_fn_triplet = nn.TripletMarginWithDistanceLoss(
            distance_function=self.distance_fn, margin=0.2)
        self.emb_cos_loss = nn.CosineEmbeddingLoss(margin=0.2)
        self.loss_kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

        # Sentinel low starting value for the best metric seen so far.
        self.best_metric = -1e3

        # normalization layer for the representations z1 and z2
        # self.bn = nn.BatchNorm1d(self.opts.prompt_dim, affine=False)

    def configure_optimizers(self):
        """Create the Adam optimizer with separate encoder/prompt learning rates.

        Encoder parameters use ``opts.clip_LN_lr``; the two prompt tensors
        use ``opts.prompt_lr``.

        When ``opts.model_type`` is not ``'one_encoder'`` the parameters of
        ``self.clip_sk`` are added to the encoder group *if that attribute
        exists*. This class never creates ``clip_sk`` itself, so reading it
        unconditionally raised ``AttributeError``; it is now treated as
        optional (e.g. supplied by a subclass).

        Returns:
            torch.optim.Adam: the configured optimizer.
        """
        model_params = list(self.dino.parameters())
        if self.opts.model_type != 'one_encoder':
            sketch_encoder = getattr(self, 'clip_sk', None)
            if sketch_encoder is not None:
                model_params += list(sketch_encoder.parameters())

        optimizer = torch.optim.Adam([
            {'params': model_params, 'lr': self.opts.clip_LN_lr},
            {'params': [self.sk_prompt, self.img_prompt], 'lr': self.opts.prompt_lr},
        ])
        return optimizer

    def forward(self, data, dtype='image'):
        """Encode a batch with the prompt that matches its input kind.

        Args:
            data: Input batch; ``data.shape[0]`` is used as the batch size
                when broadcasting the prompt.
            dtype: ``'image'`` selects ``img_prompt``; any other value
                selects ``sk_prompt``.

        Returns:
            Whatever ``self.dino`` returns for the given batch and prompt.
        """
        prompt = self.img_prompt if dtype == 'image' else self.sk_prompt
        # Broadcast the 2-D prompt to one copy per sample: (batch, n_prompts, prompt_dim).
        feat = self.dino(data, prompt=prompt.expand(data.shape[0], -1, -1))
        return feat