from typing import Literal

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.mistral import MistralConfig

NVEMBED_TYPE = "nvembed"
LATENT_ATTENTION_TYPE = "latent_attention"
BIDIR_MISTRAL_TYPE = "bidir_mistral"


class NVEmbedConfig(PretrainedConfig):
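    """Composed configuration for NV-Embed.

    Bundles a decoder ``text_config`` with a ``latent_attention_config`` for the
    latent-attention pooling head, plus tokenization and masking options
    (``padding_side``, ``add_pad_token``, ``add_eos``, ``is_mask_instruction``, ``mask_type``).
    """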
    model_type = "nvembed"
    is_composition = False

    def __init__(
        self,
        latent_attention_config=None,
        text_config=None,
        padding_side: Literal["right", "left"] = "right",
        add_pad_token: bool = True,
        is_mask_instruction: bool = True,
        add_eos: bool = True,
        mask_type: str = "b",
        **kwargs,
    ):
        # Resolve the latent-attention sub-config: accept a dict, an already-built
        # config object, or None (which falls back to the registered default).
        if isinstance(latent_attention_config, dict):
            latent_attention_config["model_type"] = latent_attention_config.get("model_type", LATENT_ATTENTION_TYPE)
            latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](**latent_attention_config)
        elif latent_attention_config is None:
            latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()

        self.latent_attention_config = latent_attention_config

        # Resolve the decoder text sub-config; a dict without "model_type" defaults to "llama".
        # An already-built config object or None is stored as-is.
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "llama")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)

        self.text_config = text_config
        self.padding_side = padding_side
        self.is_mask_instruction = is_mask_instruction
        self.add_pad_token = add_pad_token
        self.add_eos = add_eos
        self.mask_type = mask_type
        self.hidden_size = kwargs.get("hidden_size", 4096)

        super().__init__(**kwargs)


class LatentAttentionConfig(PretrainedConfig):
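    """Configuration of the latent-attention pooling block: number of latents,
    number of cross-attention heads, and the hidden/latent/cross-head dimensions."""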
    model_type = LATENT_ATTENTION_TYPE
    is_composition = False
    _name_or_path = "latent_attention"

    def __init__(
        self,
        num_latents_value: int = 512,
        num_cross_heads: int = 8,
        output_normalize: bool = True,
        hidden_dim: int = 4096,
        latent_dim: int = 4096,
        cross_dim_head: int = 4096,
        **kwargs,
    ):
        # Let PretrainedConfig process the remaining kwargs (return_dict, torch_dtype, etc.).
        super().__init__(**kwargs)
        self.num_latents_value = num_latents_value
        self.num_cross_heads = num_cross_heads
        self.output_normalize = output_normalize
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.cross_dim_head = cross_dim_head


class BidirectionalMistralConfig(MistralConfig):
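    """Mistral configuration for the decoder variant that runs with bidirectional
    (non-causal) attention."""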
    model_type = BIDIR_MISTRAL_TYPE
    keys_to_ignore_at_inference = ["past_key_values"]


AutoConfig.register(NVEMBED_TYPE, NVEmbedConfig)
AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
AutoConfig.register(BIDIR_MISTRAL_TYPE, BidirectionalMistralConfig)

NVEmbedConfig.register_for_auto_class()
LatentAttentionConfig.register_for_auto_class()
BidirectionalMistralConfig.register_for_auto_class()
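

# Minimal usage sketch: building an NVEmbedConfig from plain sub-config dicts. The
# values below are illustrative placeholders, not the released NV-Embed checkpoint
# settings; in practice the config is usually loaded via AutoConfig.from_pretrained.
if __name__ == "__main__":
    config = NVEmbedConfig(
        latent_attention_config={"num_latents_value": 512, "num_cross_heads": 8},
        text_config={"model_type": "mistral", "hidden_size": 4096},
        padding_side="right",
    )
    print(type(config.latent_attention_config).__name__)  # LatentAttentionConfig
    print(type(config.text_config).__name__)              # MistralConfig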