from typing import Literal

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.llama import LlamaConfig

GIGAREMBED_TYPE = "gigarembed"
LATENT_ATTENTION_TYPE = "latent_attention"
BIDIR_LLAMA_TYPE = "bidir_llama"


class GigarEmbedConfig(PretrainedConfig):
    """Composite configuration for GigarEmbed: a (bidirectional) Llama text backbone
    plus a latent-attention pooling head."""

    model_type = GIGAREMBED_TYPE
    is_composition = False

    def __init__(
        self,
        latent_attention_config=None,
        text_config=None,
        padding_side: Literal["right", "left"] = "right",
        add_pad_token: bool = True,
        is_mask_instruction: bool = True,
        add_eos: bool = True,
        mask_type: str = "b",
        **kwargs,
    ):
        # Resolve the latent-attention sub-config: accept a dict, an already-built
        # config object, or fall back to the registered defaults.
        if isinstance(latent_attention_config, dict):
            latent_attention_config["model_type"] = latent_attention_config.get(
                "model_type", LATENT_ATTENTION_TYPE
            )
            latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](
                **latent_attention_config
            )
        elif latent_attention_config is None:
            latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()
        self.latent_attention_config = latent_attention_config

        # Resolve the text backbone sub-config; a plain dict defaults to a Llama config,
        # and None is kept as None.
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "llama")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        self.text_config = text_config

        self.padding_side = padding_side
        self.is_mask_instruction = is_mask_instruction
        self.add_pad_token = add_pad_token
        self.add_eos = add_eos
        self.mask_type = mask_type
        # Default embedding width; kwargs is not popped, so the value is also passed on
        # to PretrainedConfig.__init__ below, matching the original behavior.
        self.hidden_size = kwargs.get("hidden_size", 2560)

        super().__init__(**kwargs)


class LatentAttentionConfig(PretrainedConfig):
    """Configuration for the latent-attention (cross-attention pooling) head."""

    model_type = LATENT_ATTENTION_TYPE
    is_composition = False
    _name_or_path = "latent_attention"

    def __init__(
        self,
        num_latents_value: int = 512,
        num_cross_heads: int = 8,
        output_normalize: bool = True,
        hidden_dim: int = 2560,
        latent_dim: int = 2560,
        cross_dim_head: int = 2560,
        **kwargs,
    ):
        self.num_latents_value = num_latents_value
        self.num_cross_heads = num_cross_heads
        self.output_normalize = output_normalize
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.cross_dim_head = cross_dim_head
        # The original omitted this call; without it the base PretrainedConfig
        # attributes (e.g. name_or_path, serialization metadata) are never set.
        super().__init__(**kwargs)


class BidirectionalLlamaConfig(LlamaConfig):
    """Llama configuration variant for a backbone run with bidirectional attention."""

    model_type = BIDIR_LLAMA_TYPE
    keys_to_ignore_at_inference = ["past_key_values"]


# Register the configs so AutoConfig (and trust_remote_code loading) can resolve
# their model_type strings.
AutoConfig.register(GIGAREMBED_TYPE, GigarEmbedConfig)
AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
AutoConfig.register(BIDIR_LLAMA_TYPE, BidirectionalLlamaConfig)

GigarEmbedConfig.register_for_auto_class()
LatentAttentionConfig.register_for_auto_class()
BidirectionalLlamaConfig.register_for_auto_class()
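

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the released configuration).
# It only assumes the registrations above have already run at import time; the
# backbone hyperparameters below are made-up small values, not GigarEmbed's.
if __name__ == "__main__":
    config = GigarEmbedConfig(
        text_config={
            # Resolved through CONFIG_MAPPING via the AutoConfig.register calls above.
            "model_type": BIDIR_LLAMA_TYPE,
            "hidden_size": 2560,        # illustrative; matches the default embedding width
            "num_hidden_layers": 2,     # kept tiny so the sketch is cheap to build
            "num_attention_heads": 8,
        },
        padding_side="left",
    )

    print(type(config.text_config).__name__)                 # BidirectionalLlamaConfig
    print(config.latent_attention_config.num_latents_value)  # 512 (default)
    print(config.hidden_size)                                 # 2560 (default)

    # The registered model_type string also resolves through the Auto machinery.
    auto_built = AutoConfig.for_model(GIGAREMBED_TYPE)
    print(type(auto_built).__name__)                          # GigarEmbedConfig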