appledora committed
Commit 6f6cf5a
1 Parent(s): 877b677

Upload recastmlp_llama/configuration_recastmlp_llama.py with huggingface_hub

recastmlp_llama/configuration_recastmlp_llama.py ADDED
@@ -0,0 +1,77 @@
+ from transformers import PretrainedConfig
+
+
+ class RECASTMLP_llama(PretrainedConfig):
+     model_type = "recastmlp_llama"
+     attribute_map = {
+         "hidden_size": "hidden_size",
+         "num_attention_heads": "num_attention_heads",
+     }
+
+     def __init__(
+         self,
+         vocab_size=128256,
+         hidden_size=4096,
+         intermediate_size=14336,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         hidden_act="silu",
+         max_position_embeddings=131072,
+         initializer_range=0.02,
+         rms_norm_eps=1e-5,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=128000,
+         eos_token_id=128001,
+         pretraining_tp=1,
+         tie_word_embeddings=False,
+         rope_theta=500000.0,
+         rope_scaling={
+             "factor": 8.0,
+             "low_freq_factor": 1.0,
+             "high_freq_factor": 4.0,
+             "original_max_position_embeddings": 8192,
+             "rope_type": "llama3",
+         },
+         attention_bias=False,
+         attention_dropout=0.0,
+         mlp_bias=False,
+         # Template-specific configs
+         num_templates=4,
+         num_groups=8,
+         num_cf=1,
+         torch_dtype="bfloat16",
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.pretraining_tp = pretraining_tp
+         self.use_cache = use_cache
+         self.mlp_bias = mlp_bias
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.torch_dtype = torch_dtype
+
+         # Template-specific configs
+         self.num_templates = num_templates
+         self.num_groups = num_groups
+         self.num_cf = num_cf
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs
+         )
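For context, the sketch below shows how a custom `PretrainedConfig` subclass like this is typically exercised: register it with `AutoConfig` under its `model_type`, instantiate it (optionally overriding the template-specific fields), and round-trip it through `save_pretrained` / `from_pretrained`. It is a minimal illustration, not part of this commit; the local import path and the override values are assumptions.

# Minimal usage sketch. Assumptions: the uploaded file is importable locally as
# configuration_recastmlp_llama, and the override values are illustrative.
from transformers import AutoConfig

from configuration_recastmlp_llama import RECASTMLP_llama

# Make AutoConfig aware of the custom model_type defined in this file.
AutoConfig.register("recastmlp_llama", RECASTMLP_llama)

# Instantiate with the Llama-3-style defaults, overriding only the
# template-specific knobs introduced by this config.
config = RECASTMLP_llama(num_templates=4, num_groups=8, num_cf=1)

# Round-trip: write config.json to a directory and reload it via AutoConfig.
config.save_pretrained("recastmlp_llama_config")
reloaded = AutoConfig.from_pretrained("recastmlp_llama_config")
assert reloaded.model_type == "recastmlp_llama"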