File size: 2,773 Bytes
bc7df4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import math

import torch
from torch import nn
from transformers import SiglipVisionModel, SiglipVisionConfig
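# 384 / 14 = 27.43 is not an integer, so the patch grid has floor(384 / 14) = 27 patches per side:
# 27 * 27 = 729 position embeddings, covering only 27 * 14 = 378 pixels per side.
# The implementation therefore uses the floor and ignores the leftover pixels.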
class SiglipEncoder(nn.Module):
    """Thin wrapper around a Hugging Face ``SiglipVisionModel``.

    The tower is constructed directly from a config (``SiglipVisionModel(config)``),
    so no pretrained weights are loaded here.
    """

    def __init__(self, vision_config):
        """Build the SigLIP vision tower.

        Args:
            vision_config: mapping of keyword arguments forwarded verbatim to
                ``SiglipVisionConfig``.
        """
        super().__init__()
        siglip_cfg = SiglipVisionConfig(**vision_config)
        self.model = SiglipVisionModel(siglip_cfg)

    def forward(self, images):
        """Encode ``images`` and return the tower's ``last_hidden_state``.

        No pooling is applied; the per-token hidden states are returned as-is
        (presumably (batch, num_patches, hidden_size) — confirm against the
        transformers version in use).
        """
        model_output = self.model(images)
        return model_output.last_hidden_state
class GLU(nn.Module):
    """Input projection followed by a SiLU-gated feed-forward network.

    Computation, with ``h`` the normalised projection::

        h   = gelu(layer_norm(linear_proj(x)))
        out = dense_4h_to_h(silu(gate_proj(h)) * dense_h_to_4h(h))

    Every linear layer is bias-free.

    Args:
        args: object exposing ``hidden_size`` (projection / output width) and
            ``intermediate_size`` (inner width of the gated FFN).
        in_features: last-dimension width of the incoming features.
    """

    def __init__(self, args, in_features):
        super().__init__()
        width = args.hidden_size
        inner = args.intermediate_size
        # Attribute names (and their creation order) are kept unchanged: the
        # names are the state_dict keys, and the order fixes the RNG-driven init.
        self.linear_proj = nn.Linear(in_features, width, bias=False)
        self.norm1 = nn.LayerNorm(width)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(width, inner, bias=False)
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.dense_4h_to_h = nn.Linear(inner, width, bias=False)

    def forward(self, x):
        """Project ``x`` to ``hidden_size`` and apply the gated FFN."""
        hidden = self.act1(self.norm1(self.linear_proj(x)))
        gate = self.act2(self.gate_proj(hidden))
        return self.dense_4h_to_h(gate * self.dense_h_to_4h(hidden))
class Adapter(nn.Module):
    """Downsample ViT patch features and project them into the LM embedding space.

    The square grid of patch tokens is merged 2x2 by a stride-2 convolution,
    projected through ``GLU``, and wrapped with learned begin-of-image
    (``boi``) and end-of-image (``eoi``) embeddings.

    Args:
        eva_hidden_size: channel width of the incoming patch features. The
            name suggests an EVA encoder, but ``VisionModel`` passes the
            SigLIP ``vision_config['hidden_size']`` here.
        args: object exposing ``hidden_size`` (output width) and
            ``intermediate_size`` (consumed by ``GLU``).
    """

    def __init__(self, eva_hidden_size, args):
        super().__init__()
        self.boi = nn.Parameter(torch.ones(1, 1, args.hidden_size, dtype=torch.float32))
        self.eoi = nn.Parameter(torch.ones(1, 1, args.hidden_size, dtype=torch.float32))
        # 2x2 patch merge with no padding: each grid side becomes side // 2, so an
        # odd side (e.g. 27) silently drops its last row/column of patches.
        self.conv = nn.Conv2d(in_channels=eva_hidden_size, out_channels=args.hidden_size, kernel_size=2, stride=2)
        self.linear_proj = GLU(args, args.hidden_size)

    def forward(self, image_emb):
        """Merge, project and bracket the patch features.

        Args:
            image_emb: (batch, num_tokens, channels) patch features, where
                num_tokens must form a square grid.

        Returns:
            Tensor of shape (batch, (side // 2) ** 2 + 2, hidden_size), where
            side = sqrt(num_tokens); the first and last positions are ``boi``
            and ``eoi``.

        Raises:
            ValueError: if num_tokens is not a perfect square.
        """
        b, s, e = image_emb.shape
        # Exact integer sqrt: a non-square token count would otherwise be floored
        # and fail later inside view() with an opaque shape error.
        grid_size = math.isqrt(s)
        if grid_size * grid_size != s:
            raise ValueError(f"Adapter expects a square grid of patch tokens, got {s} tokens")
        # (b, s, e) -> (b, e, grid, grid): channels-first layout for Conv2d.
        image_emb = image_emb.view(b, grid_size, grid_size, e).permute(0, 3, 1, 2)
        image_emb = self.conv(image_emb)  # (b, hidden, grid // 2, grid // 2)
        image_emb = image_emb.flatten(2).transpose(1, 2)  # (b, (grid // 2) ** 2, hidden)
        image_emb = self.linear_proj(image_emb)  # (b, (grid // 2) ** 2, hidden)
        # expand() is a broadcast view; torch.cat materialises it, so the extra
        # copy that repeat() made is unnecessary.
        boi = self.boi.expand(b, -1, -1)
        eoi = self.eoi.expand(b, -1, -1)
        return torch.cat([boi, image_emb, eoi], dim=1)
class VisionModel(torch.nn.Module):
    """SigLIP vision tower followed by the ``Adapter`` projection.

    Args:
        config: model config exposing ``torch_dtype`` (target dtype for input
            and output), ``vision_config`` (a mapping forwarded to
            ``SiglipEncoder`` whose ``'hidden_size'`` sets the adapter's input
            width), plus the ``hidden_size`` / ``intermediate_size`` read by
            the adapter.
    """

    def __init__(self, config):
        super().__init__()
        vision_cfg = config.vision_config
        self.dtype = config.torch_dtype
        self.vit = SiglipEncoder(vision_cfg)
        self.adapter = Adapter(vision_cfg['hidden_size'], config)

    def forward(self, image):
        """Cast ``image`` to ``self.dtype``, encode it, adapt it, and cast the result back."""
        features = self.vit(image.to(self.dtype))
        adapted = self.adapter(features)
        return adapted.to(self.dtype)
|