import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import retrieval_average_precision
import pytorch_lightning as pl

from src.dinov2.models.vision_transformer import vit_base

from functools import partial

# from src.clip import clip
from src.options import opts
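
# Note (added for clarity): this module reads model_type, n_prompts,
# prompt_dim, clip_LN_lr and prompt_lr from opts. Illustrative values
# (assumptions, not from the original config): n_prompts=3, prompt_dim=768
# to match the ViT-B embedding width.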

def freeze_model(m):
    # Freeze every parameter of the given module.
    m.requires_grad_(False)

def freeze_all_but_bn(m):
    # Freeze all parameterized modules except LayerNorm, whose affine
    # parameters remain trainable. (Despite the "bn" in the name, this
    # targets LayerNorm, not BatchNorm.)
    if not isinstance(m, torch.nn.LayerNorm):
        if hasattr(m, 'weight') and m.weight is not None:
            m.weight.requires_grad_(False)
        if hasattr(m, 'bias') and m.bias is not None:
            m.bias.requires_grad_(False)

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.opts = opts

        # DINOv2 ViT-B/14 backbone; block_chunks=0 keeps the transformer
        # blocks in a flat list, and init_values enables LayerScale.
        self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)

        # Prompt engineering: learnable prompt tokens, one set for sketches
        # and one for images, fed to the shared DINO encoder in forward().
        self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
        self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))

        # Triplet loss in cosine-distance space: d(x, y) = 1 - cos_sim(x, y).
        self.distance_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y)
        self.loss_fn_triplet = nn.TripletMarginWithDistanceLoss(
            distance_function=self.distance_fn, margin=0.2)

        self.emb_cos_loss = nn.CosineEmbeddingLoss(margin=0.2)

        self.loss_kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

        self.best_metric = -1e3
        # normalization layer for the representations z1 and z2
        # self.bn = nn.BatchNorm1d(self.opts.prompt_dim, affine=False)

    def configure_optimizers(self):
        if self.opts.model_type == 'one_encoder':
            model_params = list(self.dino.parameters())
        else:
            # Two-encoder variant: assumes self.clip_sk (a separate sketch
            # encoder, cf. the commented-out CLIP import above) has been set
            # elsewhere; it is not created in __init__ in this file.
            model_params = list(self.dino.parameters()) + list(self.clip_sk.parameters())

        # Separate learning rates: a small one for the (mostly frozen)
        # encoder parameters and a larger one for the prompt tokens.
        optimizer = torch.optim.Adam([
            {'params': model_params, 'lr': self.opts.clip_LN_lr},
            {'params': [self.sk_prompt, self.img_prompt], 'lr': self.opts.prompt_lr}])
        return optimizer

    def forward(self, data, dtype='image'):
        # Route both modalities through the shared DINO encoder, selecting the
        # modality-specific prompt and broadcasting it across the batch.
        if dtype == 'image':
            feat = self.dino(data, prompt=self.img_prompt.expand(data.shape[0], -1, -1))
        else:
            feat = self.dino(data, prompt=self.sk_prompt.expand(data.shape[0], -1, -1))
        return feat
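
# Sketch (assumption, not part of the original file): a minimal training_step
# for this model under a triplet setup, where each batch yields a sketch, a
# matching image, and a non-matching image. The batch format and logging key
# are hypothetical; the real training loop may differ.
#
#   def training_step(self, batch, batch_idx):
#       sk, img, neg = batch
#       sk_feat = self.forward(sk, dtype='sketch')
#       img_feat = self.forward(img, dtype='image')
#       neg_feat = self.forward(neg, dtype='image')
#       loss = self.loss_fn_triplet(sk_feat, img_feat, neg_feat)
#       self.log('train_loss', loss)
#       return loss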