"""MexmaSigLIP: a CLIP-style dual encoder pairing the MEXMA multilingual text
encoder (XLM-RoBERTa based) with the SigLIP so400m vision tower."""

import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    SiglipVisionConfig,
    SiglipVisionModel,
    XLMRobertaConfig,
    XLMRobertaModel,
)


class MexmaSigLIPConfig(PretrainedConfig):
    def __init__(
        self,
        optimized: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # When True, use faster attention kernels (SDPA for the text encoder,
        # FlashAttention-2 for the vision encoder).
        self.optimized = optimized


class MexmaSigLIP(PreTrainedModel):
    config_class = MexmaSigLIPConfig

    def __init__(self, config: MexmaSigLIPConfig):
        super().__init__(config)
        self.config = config

        # Text tower: MEXMA (XLM-RoBERTa large) without the pooling head.
        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
        if self.config.optimized:
            text_config._attn_implementation = "sdpa"
        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
        # Project the 1024-dim text embedding into the 1152-dim SigLIP space.
        self.text_projector = torch.nn.Linear(1024, 1152, bias=False)

        # Vision tower: SigLIP so400m, patch size 14, 384x384 input.
        vision_config = SiglipVisionConfig.from_pretrained(
            "google/siglip-so400m-patch14-384"
        )
        if self.config.optimized:
            vision_config._attn_implementation = "flash_attention_2"
        self.vision_model = SiglipVisionModel(vision_config).vision_model

        # Learnable temperature and bias for the SigLIP sigmoid loss.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)

    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
        text_features = self.encode_texts(input_ids, attention_mask, normalize)
        image_features = self.encode_images(image_inputs, normalize)
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    def encode_images(
        self,
        pixel_values,
        normalize=False,
    ):
        # Pooled output of the SigLIP vision tower, optionally L2-normalized.
        features = self.vision_model(pixel_values).pooler_output
        return F.normalize(features, dim=-1) if normalize else features

    def encode_texts(
        self,
        input_ids,
        attention_mask,
        normalize=False,
    ):
        # Use the first (CLS) token of the last hidden state as the sentence
        # embedding, then project it into the shared embedding space.
        features = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        features = self.text_projector(features)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        input_ids,
        attention_mask,
        pixel_values,
    ):
        # SigLIP-style pairwise logits: scaled cosine similarity plus bias.
        image_features = self.encode_images(pixel_values, normalize=True)
        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
        image_logits = (
            self.logit_scale.exp() * image_features @ text_features.T
            + self.logit_bias
        )
        text_logits = image_logits.T
        return image_logits, text_logits
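

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It shows
# one way to run text-image matching with the classes above. Assumptions not
# taken from the code: that "facebook/MEXMA" ships a tokenizer usable via
# AutoTokenizer, that the SigLIP image processor from
# "google/siglip-so400m-patch14-384" matches the vision tower's preprocessing,
# and a hypothetical local image file "example.jpg". The model built here is
# randomly initialized; in practice you would load trained weights (e.g. via
# from_pretrained on a released checkpoint).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer, SiglipImageProcessor

    tokenizer = AutoTokenizer.from_pretrained("facebook/MEXMA")
    processor = SiglipImageProcessor.from_pretrained(
        "google/siglip-so400m-patch14-384"
    )

    model = MexmaSigLIP(MexmaSigLIPConfig()).eval()

    texts = ["a photo of a cat", "a photo of a dog"]
    text_inputs = tokenizer(texts, padding=True, return_tensors="pt")
    image = Image.open("example.jpg")  # hypothetical local image path
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        image_logits, text_logits = model.get_logits(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
            pixel_values=pixel_values,
        )

    # SigLIP uses a sigmoid over pairwise logits rather than a softmax over
    # the batch, so each image-text pair gets an independent match probability.
    probs = torch.sigmoid(image_logits)
    print(probs)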