File size: 5,566 Bytes
56df21f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from typing import List, Tuple

from transformers import PretrainedConfig, AutoTokenizer


class MolmoVisionConfig(PretrainedConfig):
    """Hyperparameter container for the Molmo image encoder.

    Every named constructor argument is stored unchanged on the instance
    under the same attribute name; how each value is interpreted is up to
    the model code that reads this config. Any additional keyword arguments
    are forwarded to ``PretrainedConfig.__init__``.

    Args:
        image_default_input_size: ``(height, width)`` pair; unpacked in that
            order by :attr:`image_num_patch`.
        image_patch_size: Patch side length used to divide the input size in
            :attr:`image_num_patch`.
        attention_type: Name of the attention implementation. This argument
            used to be accepted but silently discarded; it is now stored as
            ``self.attention_type``. The default is spelled ``"sdpa"`` to
            match ``MolmoConfig`` (it was previously the typo ``"spda"``,
            which had no effect because the value was never kept).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    All other arguments (``image_pos_patch_size``, ``image_emb_dim``,
    ``image_num_heads``, ``image_num_key_value_heads``, ``image_num_layers``,
    ``image_head_dim``, ``image_mlp_dim``, ``image_mlp_activations``,
    ``residual_dropout``, ``image_num_pos``, ``image_norm_eps``,
    ``float32_attention``) are plain stored values.
    """

    def __init__(
        self,
        image_default_input_size: Tuple[int, int] = (336, 336),
        image_patch_size: int = 14,
        image_pos_patch_size: int = 14,
        image_emb_dim: int = 1024,
        image_num_heads: int = 16,
        image_num_key_value_heads: int = 16,
        image_num_layers: int = 23,
        image_head_dim: int = 64,
        image_mlp_dim: int = 4096,
        image_mlp_activations: str = "quick_gelu",
        residual_dropout: float = 0,
        image_num_pos: int = 577,
        image_norm_eps: float = 1e-5,
        float32_attention: bool = True,
        attention_type: str = "sdpa",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.image_default_input_size = image_default_input_size
        self.image_patch_size = image_patch_size
        self.image_pos_patch_size = image_pos_patch_size
        self.image_emb_dim = image_emb_dim
        self.image_num_heads = image_num_heads
        self.image_num_key_value_heads = image_num_key_value_heads
        self.image_num_layers = image_num_layers
        self.image_head_dim = image_head_dim
        self.image_mlp_dim = image_mlp_dim
        self.image_mlp_activations = image_mlp_activations
        self.residual_dropout = residual_dropout
        self.image_num_pos = image_num_pos
        self.image_norm_eps = image_norm_eps
        self.float32_attention = float32_attention
        # Keep the caller's choice instead of dropping it on the floor.
        self.attention_type = attention_type

    @property
    def image_num_patch(self) -> Tuple[int, int]:
        """Return ``(height // patch, width // patch)`` for the default input size.

        Uses floor division, so any remainder that does not fill a whole
        patch is ignored.
        """
        h, w = self.image_default_input_size
        return h // self.image_patch_size, w // self.image_patch_size


class MolmoConfig(PretrainedConfig):
    """Top-level configuration object for a Molmo model.

    Each named constructor argument is stored unchanged on the instance under
    the same attribute name. The one exception is ``vision_config``, which is
    normalised to a ``MolmoVisionConfig`` instance (see ``__init__``).
    Remaining keyword arguments, together with ``tie_word_embeddings``, are
    handed to ``PretrainedConfig.__init__``.
    """

    model_type = "molmo"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        embedding_size=50304,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        float32_attention=True,
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_cache=True,
        layer_norm_eps: float = 1e-5,
        rope_theta=10000.0,
        clip_qkv=None,
        activation_type="silu",
        qkv_bias: bool = False,
        weight_tying: bool = False,
        use_position_ids: bool=True,
        tie_word_embeddings: bool=True,
        bias_for_layer_norm: bool=False,
        qk_layer_norm: bool=False,
        norm_after: bool = False,
        layer_norm_type: str="rms",
        vision_config: MolmoVisionConfig=None,
        vit_layers=(-2, -9),
        residual_dropout: float=0.0,
        embedding_dropout: float=0.0,
        attention_dropout: float=0.0,
        image_feature_dropout: float=0.0,
        additional_vocab_size=128,
        attention_type: str = "sdpa",
        image_padding_embed="pad_and_partial_pad",
        moe_num_experts=None,
        moe_top_k=None,
        normalize_input_embeds: bool=False,
        scale_logits: bool=False,
        **kwargs,
    ):
        """Store all hyperparameters, then initialise the base config.

        ``vision_config`` may be ``None`` (a default ``MolmoVisionConfig`` is
        created), a ``dict`` (expanded as keyword arguments into
        ``MolmoVisionConfig``), or any other object, which is stored as
        given.
        """
        # --- vision side -------------------------------------------------
        if vision_config is None:
            resolved_vision = MolmoVisionConfig()
        elif isinstance(vision_config, dict):
            resolved_vision = MolmoVisionConfig(**vision_config)
        else:
            resolved_vision = vision_config
        self.vision_config = resolved_vision
        self.vit_layers = vit_layers
        self.image_padding_embed = image_padding_embed

        # --- vocabulary, embeddings and output logits --------------------
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.additional_vocab_size = additional_vocab_size
        self.weight_tying = weight_tying
        self.tie_word_embeddings = tie_word_embeddings
        self.normalize_input_embeds = normalize_input_embeds
        self.scale_logits = scale_logits

        # --- transformer dimensions --------------------------------------
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.activation_type = activation_type
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k

        # --- attention settings ------------------------------------------
        self.attention_type = attention_type
        self.float32_attention = float32_attention
        self.qkv_bias = qkv_bias
        self.clip_qkv = clip_qkv
        self.qk_layer_norm = qk_layer_norm
        self.rope_theta = rope_theta
        self.use_position_ids = use_position_ids
        self.use_cache = use_cache

        # --- normalisation and initialisation ----------------------------
        self.layer_norm_type = layer_norm_type
        self.layer_norm_eps = layer_norm_eps
        self.bias_for_layer_norm = bias_for_layer_norm
        self.norm_after = norm_after
        self.initializer_range = initializer_range

        # --- dropout rates -----------------------------------------------
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout
        self.attention_dropout = attention_dropout
        self.image_feature_dropout = image_feature_dropout

        # The base-class initialiser runs last, after all fields above are set.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

    @property
    def effective_num_key_value_heads(self) -> int:
        """Return ``num_key_value_heads``, or ``num_attention_heads`` if it is ``None``."""
        kv_heads = self.num_key_value_heads
        return self.num_attention_heads if kv_heads is None else kv_heads

    @property
    def image_num_patch(self):
        """Return ``image_num_patch`` of the attached vision config."""
        vision = self.vision_config
        assert vision is not None
        return vision.image_num_patch


# Register both config classes via ``register_for_auto_class()`` at import
# time. No argument is passed, so the method's default target is used
# (presumably ``AutoConfig`` -- confirm against the transformers docs).
MolmoVisionConfig.register_for_auto_class()
MolmoConfig.register_for_auto_class()