File size: 3,461 Bytes
4a9ac18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""TraVisionLM configuration"""

from transformers import PretrainedConfig
from transformers import logging, CONFIG_MAPPING
import warnings

logger = logging.get_logger(__name__)

class TraVisionLMConfig(PretrainedConfig):
    r"""
    Composite configuration for TraVisionLM: a vision-encoder config, a
    text-decoder config, and the parameters that connect the two.

    Args:
        vision_config (`dict` or `PretrainedConfig`, *optional*):
            Vision-tower configuration. A dict is converted with `CONFIG_MAPPING`
            using its `"model_type"` key (default `"siglip_vision_model"`); `None`
            builds a default SigLIP vision config (768 hidden, 12 layers,
            256px images, 16px patches). A config instance is used as-is.
        text_config (`dict` or `PretrainedConfig`, *optional*):
            Text-decoder configuration. A dict is converted with `CONFIG_MAPPING`
            using its `"model_type"` key (default `"gpt2"`); `None` builds a
            default GPT-2 config whose vocabulary size is `vocab_size`.
        ignore_index (`int`, *optional*, defaults to -100):
            Stored as `self.ignore_index` (presumably the label value skipped by
            the loss — confirm in the model code).
        image_token_idx (`int`, *optional*, defaults to 50257):
            Stored as `self.image_token_index`.
        vocab_size (`int`, *optional*, defaults to 51282):
            Vocabulary size used for the default text config. The top-level
            `vocab_size` attribute is deprecated and is kept equal to
            `text_config.vocab_size`.
        projection_dim (`int`, *optional*, defaults to 768):
            Stored on this config and copied to `vision_config.projection_dim`.
        hidden_size (`int`, *optional*, defaults to 1280):
            Stored as `self.hidden_size`.
        kwargs:
            Forwarded to `PretrainedConfig.__init__`.
    """

    model_type = "travisionlm"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_idx=50257,
        vocab_size=51282,
        projection_dim=768,
        hidden_size=1280,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        # NOTE: the argument is `image_token_idx` but the attribute (and thus the
        # serialized key) is `image_token_index`; on reload that key presumably
        # arrives via **kwargs and is restored by PretrainedConfig.__init__ — verify.
        self.image_token_index = image_token_idx
        self._vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.is_encoder_decoder = False

        if isinstance(vision_config, dict):
            # Copy first so filling in the default model_type does not mutate the caller's dict.
            vision_config = dict(vision_config)
            vision_config.setdefault("model_type", "siglip_vision_model")
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["siglip_vision_model"](
                attention_dropout=0.0,
                hidden_act="gelu_pytorch_tanh",
                hidden_size=768,
                image_size=256,
                intermediate_size=3072,
                layer_norm_eps=1e-06,
                num_attention_heads=12,
                num_channels=3,
                num_hidden_layers=12,
                patch_size=16,
            )
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            # Same defensive copy as for the vision config.
            text_config = dict(text_config)
            text_config.setdefault("model_type", "gpt2")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["gpt2"](
                activation_function="gelu_new",
                attn_pdrop=0.1,
                embd_pdrop=0.1,
                initializer_range=0.02,
                layer_norm_epsilon=1e-05,
                n_ctx=1024,
                n_embd=1280,
                n_head=20,
                n_layer=36,
                n_positions=1024,
                reorder_and_upcast_attn=False,
                resid_pdrop=0.1,
                scale_attn_by_inverse_layer_idx=False,
                scale_attn_weights=True,
                vocab_size=vocab_size
            )
        self.text_config = text_config

        # `to_dict` drops `_vocab_size`, so after save/reload only text_config carries
        # the real vocabulary size. Keep the deprecated top-level attribute in sync
        # with it (for the default GPT-2 config the two are already equal).
        self.vocab_size = getattr(self.text_config, "vocab_size", vocab_size)

        # One image token per patch of the square input image.
        self.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
        self.vision_config.projection_dim = projection_dim

        # PretrainedConfig.__init__ assigns `pad_token_id` from kwargs (default None),
        # which would overwrite any value set on `self` beforehand. Pass the text
        # config's pad id through kwargs instead; an explicit `pad_token_id` still wins.
        kwargs.setdefault("pad_token_id", self.text_config.pad_token_id)
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Deprecated alias for the vocabulary size; emits a FutureWarning on every read."""
        warnings.warn(
            "The `vocab_size` attribute is deprecated and will be removed in v4.44, Please use `text_config.vocab_size` instead.",
            FutureWarning,
        )
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, value):
        # Writes go straight to the backing field without warning.
        self._vocab_size = value

    def to_dict(self):
        """Serialize via `PretrainedConfig.to_dict`, omitting the private `_vocab_size` field."""
        output = super().to_dict()
        output.pop("_vocab_size", None)
        return output