from typing import List, Optional

import torch
import torch.nn as nn
from transformers import AutoModel, BertConfig, BertModel, PreTrainedModel

from .configuration_marqo_arctic_bge_chimera_m import ChimeraConfig


class Chimera(PreTrainedModel):
    """Two-tower embedding model: the same input is encoded by two BERT
    encoders, and their [CLS] embeddings are concatenated into one vector.
    """

    config_class = ChimeraConfig

    def __init__(self, config: ChimeraConfig):
        super().__init__(config)
        # Both sub-models share a standard BERT-base configuration.
        bert_config = BertConfig(
            vocab_size=30522,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
        )

        # Two independent BERT encoders; their pooled outputs are
        # concatenated in forward().
        self.model = nn.ModuleDict(
            {
                "model_0": BertModel(bert_config),
                "model_1": BertModel(bert_config),
            }
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode the input with every sub-model and concatenate the
        resulting [CLS] embeddings along the feature dimension."""
        embeddings = []
        for model in self.model.values():
            model_output = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
            # CLS pooling: take the hidden state of the first token.
            pooled_output = model_output[0][:, 0]
            embeddings.append(pooled_output)

        return torch.cat(embeddings, dim=-1)

    def load_weights_from_automodels(
        self, in_models: List[str], has_pooling_layer: List[bool]
    ):
        """Replace the randomly initialised sub-models with pretrained
        checkpoints loaded via AutoModel, one entry per model name."""
        model_list = []
        for i, model_name in enumerate(in_models):
            model = AutoModel.from_pretrained(
                model_name,
                add_pooling_layer=has_pooling_layer[i],
                trust_remote_code=True,
            )
            model.eval()
            model_list.append(model)

        self.model = nn.ModuleDict(
            {f"model_{i}": model for i, model in enumerate(model_list)}
        )
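

# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative addition, not part of the upstream
# file). It builds the model with randomly initialised weights and checks the
# output shape. Assumptions: ChimeraConfig() is constructible without
# arguments, and the file is executed as part of its package (the relative
# import above requires e.g. `python -m <package>.<module>`). The Hub model
# names in the commented call are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Chimera(ChimeraConfig())
    model.eval()

    # Optionally swap in pretrained encoders (names are placeholders):
    # model.load_weights_from_automodels(
    #     ["org/encoder-a", "org/encoder-b"], has_pooling_layer=[False, False]
    # )

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, 30522, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        embeddings = model(input_ids=input_ids, attention_mask=attention_mask)

    # Two BERT-base encoders at 768 dims each -> a 1536-dim concatenation.
    print(embeddings.shape)  # torch.Size([2, 1536])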