TraVisionLM-base / configuration_travisionlm.py
ucsahin's picture
Base model registered and uploaded
4a9ac18 verified
raw
history blame contribute delete
No virus
3.46 kB
"""TraVisionLM configuration"""
from transformers import PretrainedConfig
from transformers import logging, CONFIG_MAPPING
import warnings
# Module-level logger created through the `logging` helper imported from transformers
# (not referenced elsewhere in the visible part of this file).
logger = logging.get_logger(__name__)
class TraVisionLMConfig(PretrainedConfig):
    """Configuration for TraVisionLM, combining a vision config and a text config.

    Args:
        vision_config: Vision sub-configuration. A ``dict`` is turned into the
            config class registered in ``CONFIG_MAPPING`` under its
            ``"model_type"`` key (``"siglip_vision_model"`` when the key is
            absent). ``None`` builds the default ``siglip_vision_model``
            config defined below. Any other value is stored as given.
        text_config: Text sub-configuration, handled like ``vision_config``
            with ``"gpt2"`` as the fallback model type.
        ignore_index: Stored as ``self.ignore_index``.
        image_token_idx: Stored as ``self.image_token_index`` (note that the
            argument and attribute names differ).
        vocab_size: Stored privately and exposed through the deprecated
            ``vocab_size`` property; also used as ``vocab_size`` of the
            default text config.
        projection_dim: Stored on this config and copied onto
            ``vision_config.projection_dim``.
        hidden_size: Stored as ``self.hidden_size``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "travisionlm"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_idx=50257,
        vocab_size=51282,
        projection_dim=768,
        hidden_size=1280,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_idx
        # Kept under a private name so the public `vocab_size` property can
        # emit its deprecation warning on reads.
        self._vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.vision_config = vision_config
        self.is_encoder_decoder = False

        if isinstance(vision_config, dict):
            # Work on a copy so the caller's dictionary is not modified when
            # the default model type is filled in.
            vision_kwargs = dict(vision_config)
            vision_kwargs.setdefault("model_type", "siglip_vision_model")
            self.vision_config = CONFIG_MAPPING[vision_kwargs["model_type"]](**vision_kwargs)
        elif vision_config is None:
            self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
                attention_dropout=0.0,
                hidden_act="gelu_pytorch_tanh",
                hidden_size=768,
                image_size=256,
                intermediate_size=3072,
                layer_norm_eps=1e-06,
                num_attention_heads=12,
                num_channels=3,
                num_hidden_layers=12,
                patch_size=16,
            )

        self.text_config = text_config
        if isinstance(text_config, dict):
            # Same copy-then-default handling as for the vision config.
            text_kwargs = dict(text_config)
            text_kwargs.setdefault("model_type", "gpt2")
            self.text_config = CONFIG_MAPPING[text_kwargs["model_type"]](**text_kwargs)
        elif text_config is None:
            self.text_config = CONFIG_MAPPING["gpt2"](
                activation_function="gelu_new",
                attn_pdrop=0.1,
                embd_pdrop=0.1,
                initializer_range=0.02,
                layer_norm_epsilon=1e-05,
                n_ctx=1024,
                n_embd=1280,
                n_head=20,
                n_layer=36,
                n_positions=1024,
                reorder_and_upcast_attn=False,
                resid_pdrop=0.1,
                scale_attn_by_inverse_layer_idx=False,
                scale_attn_weights=True,
                vocab_size=vocab_size,
            )

        # One image token per patch of a square image.
        self.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
        self.pad_token_id = self.text_config.pad_token_id
        # PretrainedConfig.__init__ assigns `pad_token_id` from its keyword
        # arguments, which would reset the value set just above. Supplying it
        # as a default keeps the text model's pad token while still letting an
        # explicit `pad_token_id` passed by the caller take precedence.
        kwargs.setdefault("pad_token_id", self.text_config.pad_token_id)
        self.vision_config.projection_dim = projection_dim
        # The base initializer runs last, so any remaining keyword arguments
        # are handled after the attributes above have been set.
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        """Return the stored vocabulary size, emitting a ``FutureWarning``."""
        warnings.warn(
            "The `vocab_size` attribute is deprecated and will be removed in v4.44, Please use `text_config.vocab_size` instead.",
            FutureWarning,
        )
        return self._vocab_size

    @vocab_size.setter
    def vocab_size(self, value):
        self._vocab_size = value

    def to_dict(self):
        """Return the base-class dictionary without the private ``_vocab_size`` entry."""
        output = super().to_dict()
        output.pop("_vocab_size", None)
        return output