File size: 2,104 Bytes
6353c49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import BertModel, PreTrainedModel, BertConfig, AutoModel

from .configuration_marqo_arctic_bge_chimera_m import ChimeraConfig
class Chimera(PreTrainedModel):
    """Ensemble of encoders whose per-model embeddings are concatenated.

    The sub-models are held in ``self.model``, an ``nn.ModuleDict`` keyed
    ``"model_0"``, ``"model_1"``, ...  The constructor creates two
    ``BertModel`` instances from a fixed configuration (no pretrained weights
    are loaded there); ``load_weights_from_automodels`` replaces them with
    checkpoints loaded through ``AutoModel``.
    """

    config_class = ChimeraConfig

    def __init__(self, config: ChimeraConfig):
        """Create the two placeholder BERT encoders.

        Args:
            config: Configuration handed to ``PreTrainedModel.__init__``. The
                encoder architecture below is hard-coded and does not read
                any value from it.
        """
        super().__init__(config)
        bert_config = BertConfig(
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
        )
        # The keys become part of the parameter names, so they are spelled
        # out explicitly and match the scheme used when loading checkpoints.
        self.model = nn.ModuleDict(
            {
                "model_0": BertModel(bert_config),
                "model_1": BertModel(bert_config),
            }
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run every sub-model on the same inputs and concatenate the results.

        Args:
            input_ids: Token ids, forwarded unchanged to each sub-model.
            attention_mask: Attention mask, forwarded unchanged to each
                sub-model.
            token_type_ids: Optional segment ids, forwarded unchanged to each
                sub-model (``None`` is passed through as-is).

        Returns:
            The per-model embeddings joined along the last dimension, in the
            iteration order of ``self.model``.
        """
        embeddings = [
            # ``[0]`` takes the first output of the sub-model and ``[:, 0]``
            # keeps index 0 along its second axis -- presumably the [CLS]
            # position of the last hidden state; confirm for non-BERT
            # checkpoints loaded via ``load_weights_from_automodels``.
            encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )[0][:, 0]
            for encoder in self.model.values()
        ]
        return torch.cat(embeddings, dim=-1)

    def load_weights_from_automodels(
        self, in_models: List[str], has_pooling_layer: List[bool]
    ):
        """Replace the sub-models with checkpoints loaded via ``AutoModel``.

        Each loaded model is switched to evaluation mode and stored under the
        key ``"model_<i>"``, where ``i`` is its position in ``in_models``.

        Args:
            in_models: Identifiers passed to ``AutoModel.from_pretrained``,
                one per sub-model.
            has_pooling_layer: For each entry of ``in_models``, the value
                forwarded as ``add_pooling_layer``. Entries beyond
                ``len(in_models)`` are ignored.

        Raises:
            IndexError: If ``has_pooling_layer`` has fewer entries than
                ``in_models``.
        """
        if len(has_pooling_layer) < len(in_models):
            # Fail before any checkpoint is loaded. IndexError is kept because
            # that is what the positional lookup raised for this case, only
            # part-way through loading.
            raise IndexError(
                f"has_pooling_layer has {len(has_pooling_layer)} entries but "
                f"in_models has {len(in_models)}; one flag per model is required"
            )

        loaded = {}
        for i, (model_name, add_pooling) in enumerate(
            zip(in_models, has_pooling_layer)
        ):
            # NOTE(review): trust_remote_code=True lets the checkpoint's
            # repository supply code that is executed locally -- only pass
            # model names from trusted sources.
            model = AutoModel.from_pretrained(
                model_name,
                add_pooling_layer=add_pooling,
                trust_remote_code=True,
            )
            model.eval()
            loaded[f"model_{i}"] = model

        self.model = nn.ModuleDict(loaded)
|