Update config.json
Browse files
WizardCoder-Python-34B-V1.0 was trained with transformers 4.31.0. In transformers 4.31.0, `rope_theta` was not used when initializing the rotary embedding, so WizardCoder-Python-34B-V1.0 used the default value of the `base` parameter, which is 10000.
The rotary embedding is initialized here (note that no `base` argument is passed):
```
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
```
For reference, the default value of `base` is defined in the constructor:
```
class LlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding module (constructor excerpt).

    Args:
        dim: Size used to generate the inverse-frequency vector; one
            frequency is produced for every second index in ``[0, dim)``.
        max_position_embeddings: Sequence length passed to
            ``_set_cos_sin_cache`` at construction time. Defaults to 2048.
        base: Base that is raised to ``2i / dim`` when computing the inverse
            frequencies. Defaults to 10000, which is the value used whenever
            the caller does not pass ``base`` explicitly.
        device: Device the frequency tensor is moved to (``None`` keeps the
            default placement).
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = 1 / base ** (2i / dim) for the even indices 0, 2, ..., < dim.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        # Registered as a non-persistent buffer, so it is not part of state_dict().
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        # NOTE(review): `_set_cos_sin_cache` is defined outside this excerpt.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
```
- config.json +1 -1
@@ -18,7 +18,7 @@
|
|
18 |
"pretraining_tp": 1,
|
19 |
"rms_norm_eps": 1e-05,
|
20 |
"rope_scaling": null,
|
21 |
-
"rope_theta":
|
22 |
"tie_word_embeddings": false,
|
23 |
"torch_dtype": "float16",
|
24 |
"transformers_version": "4.31.0",
|
|
|
18 |
"pretraining_tp": 1,
|
19 |
"rms_norm_eps": 1e-05,
|
20 |
"rope_scaling": null,
|
21 |
+
"rope_theta": 10000,
|
22 |
"tie_word_embeddings": false,
|
23 |
"torch_dtype": "float16",
|
24 |
"transformers_version": "4.31.0",
|