from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import CLIPPreTrainedModel, CLIPVisionConfig, CLIPVisionModel, SiglipTextConfig, SiglipTextModel
from transformers.models.clip.modeling_clip import CLIPOutput, clip_loss

from .configuration_mitsua_japanese_clip import MitsuaJapaneseCLIPConfig


class MitsuaJapaneseCLIPModel(CLIPPreTrainedModel):
    config_class = MitsuaJapaneseCLIPConfig

    def __init__(self, config: MitsuaJapaneseCLIPConfig):
        CLIPPreTrainedModel.__init__(self, config)

        if not isinstance(config.text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size
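
        # Text tower: a SigLIP text encoder, of which only the inner `text_model` module is kept.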
        text_model = SiglipTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
        self.text_model = text_model.text_model
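
        # Vision tower: a standard CLIP ViT, of which only the inner `vision_model` module is kept.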
        vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
        self.vision_model = vision_model.vision_model
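
        # Only the image side gets a learned projection to `projection_dim`; the SigLIP text
        # tower's pooled output is used directly as the text embedding, so there is no
        # `text_projection`. `logit_scale` is the CLIP temperature, stored in log space and
        # exponentiated in `forward`.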
        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
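        r"""
        Returns:
            text_features (`torch.FloatTensor`): The pooled output of the SigLIP text tower, used
            directly as the text embedding; unlike the vision side, no projection layer is applied.
        """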
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
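
        # Pooled output (index 1) of the SigLIP text model, returned without a text projection.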
        pooled_output = text_outputs[1]
        return pooled_output

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
r"""
|
|
Returns:
|
|
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
|
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
|
Examples:
|
|
```python
|
|
>>> from PIL import Image
|
|
>>> import requests
|
|
>>> from transformers import AutoProcessor, CLIPModel
|
|
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
|
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
>>> inputs = processor(images=image, return_tensors="pt")
|
|
>>> image_features = model.get_image_features(**inputs)
|
|
```"""
|
|
|
|
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
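
        # The vision tower's pooled output (index 1) is projected into the shared embedding space.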
        pooled_output = vision_outputs[1]
        image_features = self.visual_projection(pooled_output)

        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
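        r"""
        Runs both towers and returns CLIP-style image-text similarity logits. The example below is
        a usage sketch under assumptions: the checkpoint id is a placeholder, and it is assumed to
        expose this class and a matching processor via `auto_map` with `trust_remote_code=True`.

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoModel, AutoProcessor

        >>> repo = "<mitsua-japanese-clip-checkpoint>"  # placeholder, substitute the real repo id
        >>> model = AutoModel.from_pretrained(repo, trust_remote_code=True)
        >>> processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["猫の写真", "犬の写真"]  # "a photo of a cat", "a photo of a dog"
        >>> inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)

        >>> outputs = model(**inputs)
        >>> probs = outputs.logits_per_image.softmax(dim=1)  # image-text similarity as probabilities
        ```"""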
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
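
        # L2-normalize both embeddings so the dot products below are cosine similarities.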
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
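
        # Cosine similarity as logits, scaled by the learned (exponentiated) temperature.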
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * logit_scale.to(
            text_embeds.device
        )
        logits_per_image = logits_per_text.t()
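
        # Standard CLIP contrastive loss: symmetric cross-entropy over the similarity matrix.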
        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                text_outputs,
                vision_outputs,
            )
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )