jupyterjazz committed on
Commit 3830381
1 Parent(s): 1752c7c

configuration


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (1)
  1. configuration_xlm_roberta.py +85 -36
configuration_xlm_roberta.py CHANGED
@@ -1,44 +1,89 @@
-from transformers import PretrainedConfig
+from typing import Any, Dict, List, Optional, Union
+
 import torch
+from transformers import PretrainedConfig
+
 
 class XLMRobertaFlashConfig(PretrainedConfig):
     def __init__(
-        self,
-        vocab_size=30522,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=2,
-        initializer_range=0.02,
-        layer_norm_eps=1e-12,
-        pad_token_id=1,
-        bos_token_id=0,
-        eos_token_id=2,
-        position_embedding_type="absolute",
-        rotary_emb_base=10000.0,
-        use_cache=True,
-        classifier_dropout=None,
-        lora_adaptations=None,
-        lora_prompts=None,
-        lora_rank=4,
-        lora_dropout_p=0.0,
-        lora_alpha=1,
-        lora_main_params_trainable=False,
-        load_trained_adapters=False,
-        use_flash_attn=True,
-        torch_dtype=None,
-        emb_pooler=None,
-        matryoshka_dimensions=None,
-        truncate_dim=None,
-        **kwargs,
+        self,
+        vocab_size: int = 250002,
+        hidden_size: int = 1024,
+        num_hidden_layers: int = 24,
+        num_attention_heads: int = 16,
+        intermediate_size: int = 4096,
+        hidden_act: str = "gelu",
+        hidden_dropout_prob: float = 0.1,
+        attention_probs_dropout_prob: float = 0.1,
+        max_position_embeddings: int = 8194,
+        type_vocab_size: int = 1,
+        initializer_range: float = 0.02,
+        layer_norm_eps: float = 1e-05,
+        pad_token_id: int = 1,
+        bos_token_id: int = 0,
+        eos_token_id: int = 2,
+        position_embedding_type: str = "rotary",
+        rotary_emb_base: float = 10000.0,
+        use_cache: bool = True,
+        classifier_dropout: Optional[float] = None,
+        lora_adaptations: Optional[List[str]] = None,
+        lora_prompts: Optional[Dict[str, str]] = None,
+        lora_rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: int = 1,
+        lora_main_params_trainable: bool = False,
+        load_trained_adapters: bool = False,
+        use_flash_attn: bool = True,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        emb_pooler: Optional[str] = None,
+        matryoshka_dimensions: Optional[List[int]] = None,
+        truncate_dim: Optional[int] = None,
+        **kwargs: Dict[str, Any],
     ):
-        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        """
+        Initialize the XLMRobertaFlashConfig configuration.
+
+        Args:
+            vocab_size (int): Size of the vocabulary.
+            hidden_size (int): Dimensionality of the encoder layers and the pooler layer.
+            num_hidden_layers (int): Number of hidden layers in the Transformer encoder.
+            num_attention_heads (int): Number of attention heads for each attention layer in the Transformer encoder.
+            intermediate_size (int): Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer.
+            hidden_act (str): The activation function to use.
+            hidden_dropout_prob (float): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+            attention_probs_dropout_prob (float): The dropout ratio for the attention probabilities.
+            max_position_embeddings (int): The maximum length of the position embeddings.
+            type_vocab_size (int): The vocabulary size of the token type ids.
+            initializer_range (float): The standard deviation for initializing all weight matrices.
+            layer_norm_eps (float): The epsilon used by the layer normalization layers.
+            pad_token_id (int): The ID of the padding token.
+            bos_token_id (int): The ID of the beginning-of-sequence token.
+            eos_token_id (int): The ID of the end-of-sequence token.
+            position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
+            rotary_emb_base (float): Base for rotary embeddings.
+            use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
+            classifier_dropout (Optional[float]): The dropout ratio for the classification head.
+            lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
+            lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
+            lora_rank (int): Rank for LoRA adaptations.
+            lora_dropout_p (float): Dropout probability for LoRA adaptations.
+            lora_alpha (int): Alpha parameter for LoRA.
+            lora_main_params_trainable (bool): Whether to make the main model parameters trainable when using LoRA.
+            load_trained_adapters (bool): Whether to load trained adapters.
+            use_flash_attn (bool): Whether to use FlashAttention.
+            torch_dtype (Optional[Union[str, torch.dtype]]): Data type for the tensors.
+            emb_pooler (Optional[str]): Pooling layer configuration.
+            matryoshka_dimensions (Optional[List[int]]): Configuration for matryoshka dimension reduction.
+            truncate_dim (Optional[int]): Dimension to truncate embeddings to, if any.
+            **kwargs (Dict[str, Any]): Additional keyword arguments passed to the configuration.
+        """
 
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
 
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
@@ -67,7 +112,11 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.emb_pooler = emb_pooler
         self.matryoshka_dimensions = matryoshka_dimensions
         self.truncate_dim = truncate_dim
-        if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
+        if (
+            torch_dtype
+            and hasattr(torch, torch_dtype)
+            and type(getattr(torch, torch_dtype)) is torch.dtype
+        ):
            self.torch_dtype = getattr(torch, torch_dtype)
        else:
            self.torch_dtype = torch_dtype
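
Below is a minimal usage sketch, not part of the commit: the class name, parameter names, and defaults come from the diff above, while the argument values are illustrative, and it assumes configuration_xlm_roberta.py is importable from the working directory and that the parts of __init__ not shown in this diff only assign the remaining attributes.

# Illustrative sketch only; see assumptions in the paragraph above.
import torch

from configuration_xlm_roberta import XLMRobertaFlashConfig

# The new defaults describe the large XLM-RoBERTa variant
# (hidden_size=1024, vocab_size=250002, rotary position embeddings).
config = XLMRobertaFlashConfig()
assert config.hidden_size == 1024
assert config.vocab_size == 250002

# A torch_dtype passed as a string that names a real torch dtype is resolved
# to the dtype object by the check at the end of __init__.
config = XLMRobertaFlashConfig(torch_dtype="bfloat16")
assert config.torch_dtype is torch.bfloat16

# Any other value (e.g. None) is stored unchanged.
config = XLMRobertaFlashConfig(torch_dtype=None)
assert config.torch_dtype is None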