|
from .configuration_clip_camembert import CLIPTextCamembertConfig |
|
from transformers import ( |
|
CamembertModel, |
|
CLIPTextModelWithProjection, |
|
) |
|
from transformers.models.clip.modeling_clip import CLIPTextModelOutput |
|
import torch |
|
from torch import nn |
|
from typing import Any, Optional, Tuple, Union |
|
|
|
|
|
class CLIPTextCamembertModelWithProjection(CLIPTextModelWithProjection):
    """CLIP-style text tower whose encoder is a ``CamembertModel``.

    The parent ``CLIPTextModelWithProjection`` is initialised first. Its
    ``text_model`` attribute is then replaced by a ``CamembertModel`` built
    from the same config. Its ``text_projection`` is replaced by a bias-free
    linear layer mapping ``config.hidden_size`` to ``config.projection_dim``.
    """

    config_class = CLIPTextCamembertConfig

    def __init__(self, config: CLIPTextCamembertConfig):
        super().__init__(config)

        # NOTE: the order encoder -> projection -> post_init is kept on
        # purpose. Module constructors may draw from torch's global RNG, so
        # reordering them could change the freshly initialised weights.
        # NOTE(review): super().__init__ already builds CLIP's own text tower,
        # which is discarded on the next line. Confirm that cost is acceptable.
        self.text_model = CamembertModel(config)

        in_features, out_features = config.hidden_size, config.projection_dim
        self.text_projection = nn.Linear(in_features, out_features, bias=False)

        # Hugging Face post-construction hook (presumably weight init / final
        # setup), run after the submodules above have been swapped in.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        """Encode text with the Camembert encoder and project it for CLIP.

        Args:
            input_ids: token ids, forwarded unchanged to ``self.text_model``.
            attention_mask: forwarded unchanged to ``self.text_model``.
            position_ids: forwarded unchanged to ``self.text_model``.
            output_attentions: forwarded unchanged to ``self.text_model``.
            output_hidden_states: forwarded unchanged to ``self.text_model``.
            return_dict: if ``None``, ``self.config.use_return_dict`` is used.

        Returns:
            If ``return_dict`` is truthy, a ``CLIPTextModelOutput`` carrying
            ``text_embeds`` plus the encoder's ``last_hidden_state``,
            ``hidden_states`` and ``attentions``.

            Otherwise, a tuple ``(text_embeds, last_hidden_state, *rest)``
            where ``rest`` is the encoder tuple from index 2 onward, with every
            ``None`` entry removed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_out = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of the encoder output serves as the sentence vector. For a
        # CamembertModel built with its default pooling layer this is the
        # pooler output.
        text_embeds = self.text_projection(encoder_out[1])

        if return_dict:
            return CLIPTextModelOutput(
                text_embeds=text_embeds,
                last_hidden_state=encoder_out.last_hidden_state,
                hidden_states=encoder_out.hidden_states,
                attentions=encoder_out.attentions,
            )

        flat = (text_embeds, encoder_out[0], *encoder_out[2:])
        return tuple(item for item in flat if item is not None)

    def converter_weight(
        self, path_model="airesearch/wangchanberta-base-att-spm-uncased"
    ):
        r"""Copy pretrained CamemBERT-architecture weights into ``self.text_model``.

        Args:
            path_model: identifier passed to ``CamembertModel.from_pretrained``.
                Defaults to ``airesearch/wangchanberta-base-att-spm-uncased``.

        The source encoder is loaded in full, then its ``state_dict`` is copied
        into ``self.text_model`` with ``load_state_dict``'s default (strict)
        key matching. ``self.text_projection`` is not touched.
        """
        donor = CamembertModel.from_pretrained(path_model)
        weights = donor.state_dict()
        self.text_model.load_state_dict(weights)