|
from typing import Literal |
|
from transformers import AutoConfig |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.models.auto import CONFIG_MAPPING |
|
from transformers.models.llama import LlamaConfig |
|
|
|
# ``model_type`` identifiers for the config classes defined in this module;
# each one is registered with ``AutoConfig`` at the bottom of the file.
GIGAREMBED_TYPE: str = "gigarembed"  # top-level GigarEmbed config
LATENT_ATTENTION_TYPE: str = "latent_attention"  # latent-attention sub-config
BIDIR_LLAMA_TYPE: str = "bidir_llama"  # bidirectional Llama config
|
|
|
class GigarEmbedConfig(PretrainedConfig):
    """Composite configuration for the GigarEmbed model.

    Holds a latent-attention sub-config, an optional text-model sub-config and
    a handful of flags that are stored verbatim as attributes. How the flags
    are consumed is defined by the model code, which is not part of this file.
    """

    # Kept in sync with the identifier used for ``AutoConfig.register``.
    model_type = GIGAREMBED_TYPE
    is_composition = False

    def __init__(
        self,
        latent_attention_config=None,
        text_config=None,
        padding_side: Literal["right", "left"] = "right",
        add_pad_token: bool = True,
        is_mask_instruction: bool = True,
        add_eos: bool = True,
        mask_type: str = "b",
        **kwargs,
    ):
        """Build the config, turning dict sub-configs into config objects.

        Args:
            latent_attention_config: A dict, a config object, or ``None``.
                A dict is instantiated via ``CONFIG_MAPPING`` keyed by its
                ``"model_type"`` entry (default ``"latent_attention"``).
                ``None`` yields a default-constructed latent-attention config.
                Any other value is stored unchanged.
            text_config: A dict, a config object, or ``None``. A dict is
                instantiated via ``CONFIG_MAPPING`` keyed by its
                ``"model_type"`` entry (default ``"llama"``). ``None`` and any
                other value are stored unchanged.
            padding_side: Stored as ``self.padding_side``.
            add_pad_token: Stored as ``self.add_pad_token``.
            is_mask_instruction: Stored as ``self.is_mask_instruction``.
            add_eos: Stored as ``self.add_eos``.
            mask_type: Stored as ``self.mask_type``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
                ``hidden_size`` is additionally read here (default 2560).
        """
        if isinstance(latent_attention_config, dict):
            latent_attention_config = self._sub_config_from_dict(
                latent_attention_config, LATENT_ATTENTION_TYPE
            )
        elif latent_attention_config is None:
            latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()
        self.latent_attention_config = latent_attention_config

        if isinstance(text_config, dict):
            text_config = self._sub_config_from_dict(text_config, "llama")
        # A missing text config stays None: no default backbone is created.
        self.text_config = text_config

        self.padding_side = padding_side
        self.is_mask_instruction = is_mask_instruction
        self.add_pad_token = add_pad_token
        self.add_eos = add_eos
        self.mask_type = mask_type
        # ``hidden_size`` is read without popping, so it is still forwarded
        # to the base class along with the other kwargs.
        self.hidden_size = kwargs.get("hidden_size", 2560)

        super().__init__(**kwargs)

    @staticmethod
    def _sub_config_from_dict(config_dict, default_model_type):
        """Instantiate the registered config class described by a plain dict.

        The dict's ``"model_type"`` entry selects the class from
        ``CONFIG_MAPPING`` (``default_model_type`` when absent). A shallow copy
        is used so the caller's dict is never mutated.
        """
        config_dict = dict(config_dict)
        config_dict.setdefault("model_type", default_model_type)
        return CONFIG_MAPPING[config_dict["model_type"]](**config_dict)
|
|
|
|
|
class LatentAttentionConfig(PretrainedConfig):
    """Configuration for GigarEmbed's latent-attention sub-module.

    Every named constructor argument is stored verbatim as an attribute of the
    same name. How the values are used is defined by the model code, which is
    not part of this file.
    """

    model_type = LATENT_ATTENTION_TYPE
    is_composition = False
    # Class-level default reported by ``name_or_path``.
    _name_or_path = "latent_attention"

    def __init__(
        self,
        num_latents_value: int = 512,
        num_cross_heads: int = 8,
        output_normalize: bool = True,
        hidden_dim: int = 2560,
        latent_dim: int = 2560,
        cross_dim_head: int = 2560,
        **kwargs,
    ):
        """Store the hyper-parameters and initialise the base config.

        Args:
            num_latents_value: Stored as ``self.num_latents_value``
                (presumably the number of latent vectors — confirm).
            num_cross_heads: Stored as ``self.num_cross_heads``
                (presumably the number of cross-attention heads — confirm).
            output_normalize: Stored as ``self.output_normalize``.
            hidden_dim: Stored as ``self.hidden_dim``.
            latent_dim: Stored as ``self.latent_dim``.
            cross_dim_head: Stored as ``self.cross_dim_head``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``, as
                Transformers requires of custom configurations.
        """
        self.num_latents_value = num_latents_value
        self.num_cross_heads = num_cross_heads
        self.output_normalize = output_normalize
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.cross_dim_head = cross_dim_head

        # PretrainedConfig.__init__ assigns the instance ``_name_or_path`` from
        # the ``name_or_path`` kwarg (default ""), which would shadow the class
        # attribute above. Forward the class default so ``name_or_path`` keeps
        # reporting it unless the caller supplies a value.
        kwargs.setdefault("name_or_path", self._name_or_path)
        super().__init__(**kwargs)
|
|
|
|
|
class BidirectionalLlamaConfig(LlamaConfig):
    """``LlamaConfig`` registered under its own ``model_type``.

    Adds no new fields: it differs from ``LlamaConfig`` only by its
    ``model_type`` identifier and by declaring
    ``keys_to_ignore_at_inference`` explicitly.
    """

    keys_to_ignore_at_inference = ["past_key_values"]
    model_type = BIDIR_LLAMA_TYPE
|
|
|
# Make each custom ``model_type`` string resolvable through ``AutoConfig``.
for _model_type, _config_cls in (
    (GIGAREMBED_TYPE, GigarEmbedConfig),
    (LATENT_ATTENTION_TYPE, LatentAttentionConfig),
    (BIDIR_LLAMA_TYPE, BidirectionalLlamaConfig),
):
    AutoConfig.register(_model_type, _config_cls)

# Tag each class with ``AutoConfig`` as its auto class (default argument of
# ``register_for_auto_class``) for Transformers' custom-code support.
for _config_cls in (
    GigarEmbedConfig,
    LatentAttentionConfig,
    BidirectionalLlamaConfig,
):
    _config_cls.register_for_auto_class()

# Drop the loop variables so the module namespace matches the explicit calls.
del _model_type, _config_cls
|
|