# marqo-chimera-arctic-bge-m / modeling_marqo_arctic_bge_chimera_m.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, PreTrainedModel, BertConfig, AutoModel
from typing import List, Optional
from .configuration_marqo_arctic_bge_chimera_m import ChimeraConfig


class Chimera(PreTrainedModel):
    """Two-tower ensemble: the input is encoded by two BERT models and their
    L2-normalized [CLS] embeddings are concatenated into a single vector."""

    config_class = ChimeraConfig

    def __init__(self, config: ChimeraConfig):
        super().__init__(config)
        # Both sub-models share a standard BERT-base architecture.
        bert_config = BertConfig(
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
)
self.model = nn.ModuleDict(
{
"model_0": BertModel(bert_config),
"model_1": BertModel(bert_config),
}
)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch with every sub-model and concatenate the embeddings."""
embeddings = []
        for model in self.model.values():
model_output = model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
            # Pool with the [CLS] token (first position), then L2-normalize.
            pooled_output = model_output[0][:, 0]
            pooled_output = F.normalize(pooled_output, p=2, dim=-1)
            embeddings.append(pooled_output)
        # The final embedding is the concatenation of all sub-model embeddings.
        return torch.cat(embeddings, dim=-1)

    def load_weights_from_automodels(
        self, in_models: List[str], has_pooling_layer: List[bool]
    ):
        """Replace the sub-models with pretrained checkpoints loaded via AutoModel."""
model_list = []
for i, model_name in enumerate(in_models):
model = AutoModel.from_pretrained(
model_name,
add_pooling_layer=has_pooling_layer[i],
trust_remote_code=True,
)
model.eval()
model_list.append(model)
self.model = nn.ModuleDict(
{f"model_{i}": model for i, model in enumerate(model_list)}
)
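

# --- Usage sketch (illustrative; not part of the original file) ---
# This module uses a relative import and is normally loaded through
# `trust_remote_code=True`, so the example below is meant to be run as a
# separate script. The repository id "Marqo/marqo-chimera-arctic-bge-m" and
# the tokenizer choice are assumptions based on the file path above.
#
#     import torch
#     from transformers import AutoModel, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("Marqo/marqo-chimera-arctic-bge-m")
#     model = AutoModel.from_pretrained(
#         "Marqo/marqo-chimera-arctic-bge-m", trust_remote_code=True
#     )
#     model.eval()
#
#     batch = tokenizer(
#         ["what is vector search?"],
#         padding=True,
#         truncation=True,
#         return_tensors="pt",
#     )
#     with torch.no_grad():
#         embeddings = model(**batch)
#
#     # Each tower contributes a 768-dim vector, so the output is 1536-dim.
#     print(embeddings.shape)  # torch.Size([1, 1536])
#
# To assemble the ensemble from its constituent checkpoints instead, one
# could call `load_weights_from_automodels`; the two checkpoint ids below
# are assumptions suggested by the repository name:
#
#     model.load_weights_from_automodels(
#         ["Snowflake/snowflake-arctic-embed-m", "BAAI/bge-base-en-v1.5"],
#         has_pooling_layer=[True, True],
#     )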