from typing import Literal

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.mistral import MistralConfig

# Model-type identifiers used to register these configs with AutoConfig.
NVEMBED_TYPE = "nvembed"
LATENT_ATTENTION_TYPE = "latent_attention"
BIDIR_MISTRAL_TYPE = "bidir_mistral"

class NVEmbedConfig(PretrainedConfig):
    model_type = "nvembed"
    is_composition = False

    def __init__(
        self,
        hidden_size=4096,
        latent_attention_config=None,
        text_config=None,
        padding_side: Literal["right", "left"] = "right",
        add_pad_token: bool = True,
        is_mask_instruction: bool = True,
        add_eos: bool = True,
        mask_type: str = "b",
        **kwargs,
    ):
        # Build the latent-attention sub-config: from a dict (defaulting its
        # model_type) or, if none is given, from the registered default class.
        if isinstance(latent_attention_config, dict):
            latent_attention_config["model_type"] = (
                latent_attention_config["model_type"]
                if "model_type" in latent_attention_config
                else LATENT_ATTENTION_TYPE
            )
            latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](**latent_attention_config)
        elif latent_attention_config is None:
            latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()
        self.latent_attention_config = latent_attention_config

        # Build the decoder (text) sub-config from a dict; otherwise it may remain None.
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = None

        self.hidden_size = hidden_size
        self.text_config = text_config
        self.padding_side = padding_side
        self.is_mask_instruction = is_mask_instruction
        self.add_pad_token = add_pad_token
        self.add_eos = add_eos
        self.mask_type = mask_type
        super().__init__(**kwargs)

class LatentAttentionConfig(PretrainedConfig):
    model_type = LATENT_ATTENTION_TYPE
    is_composition = False
    _name_or_path = "latent_attention"

    def __init__(
        self,
        num_latents_value: int = 512,
        num_cross_heads: int = 8,
        output_normalize: bool = True,
        hidden_dim: int = 4096,
        latent_dim: int = 4096,
        cross_dim_head: int = 4096,
        **kwargs,
    ):
        self.num_latents_value = num_latents_value
        self.num_cross_heads = num_cross_heads
        self.output_normalize = output_normalize
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.cross_dim_head = cross_dim_head
        # Forward remaining kwargs so standard PretrainedConfig attributes are initialized.
        super().__init__(**kwargs)

class BidirectionalMistralConfig(MistralConfig):
    # Mistral configuration re-typed for the bidirectional-attention variant.
    model_type = BIDIR_MISTRAL_TYPE
    keys_to_ignore_at_inference = ["past_key_values"]


# Register the custom model types so AutoConfig / CONFIG_MAPPING can resolve them,
# and mark the classes for auto_map export when a config is saved.
AutoConfig.register(NVEMBED_TYPE, NVEmbedConfig)
AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
AutoConfig.register(BIDIR_MISTRAL_TYPE, BidirectionalMistralConfig)

NVEmbedConfig.register_for_auto_class()
LatentAttentionConfig.register_for_auto_class()
BidirectionalMistralConfig.register_for_auto_class()
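

# Illustrative usage sketch (an assumed example, not part of the original module):
# once this file has been imported, the registered model types resolve through
# CONFIG_MAPPING and AutoConfig, and sub-configs can be passed as plain dicts.
if __name__ == "__main__":
    # Default construction: the latent-attention sub-config is instantiated from
    # the registered LATENT_ATTENTION_TYPE entry in CONFIG_MAPPING.
    config = NVEmbedConfig()
    assert isinstance(config.latent_attention_config, LatentAttentionConfig)

    # A text sub-config given as a dict selects its concrete class via "model_type".
    config = NVEmbedConfig(text_config={"model_type": BIDIR_MISTRAL_TYPE, "hidden_size": 4096})
    assert isinstance(config.text_config, BidirectionalMistralConfig)

    # AutoConfig.for_model also resolves the custom type after registration.
    auto_built = AutoConfig.for_model(NVEMBED_TYPE)
    print(type(auto_built).__name__)  # NVEmbedConfig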