mklf committed
Commit
9201d24
1 Parent(s): d869ce1

Update config.json


WizardCoder-Python-34B-V1.0 was trained with transformers 4.31.0. In transformers 4.31.0, `rope_theta` is not used when initializing the rotary embedding, so WizardCoder-Python-34B-V1.0 was trained with the default value of the `base` parameter, which is 10000.

The rotary embedding is initialized here (`_init_rope` in `modeling_llama.py`, transformers 4.31.0):
```
def _init_rope(self):
    if self.config.rope_scaling is None:
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
    else:
        scaling_type = self.config.rope_scaling["type"]
        scaling_factor = self.config.rope_scaling["factor"]
        if scaling_type == "linear":
            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
            )
        elif scaling_type == "dynamic":
            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
```
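
Because no `base` argument is forwarded there, whatever `rope_theta` says in config.json is ignored and the constructor default is used. A minimal check of that behaviour (assuming transformers 4.31.0 is installed; the dimensions below are illustrative, not read from the model config):

```
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Mirror the call made in _init_rope above: `base` is never passed,
# so the embedding keeps its constructor default regardless of config.json.
rotary_emb = LlamaRotaryEmbedding(128, max_position_embeddings=16384)
print(rotary_emb.base)  # 10000
```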


For reference, `LlamaRotaryEmbedding.__init__` in 4.31.0, where `base` defaults to 10000:
```
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
```
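
The `base` value feeds directly into the inverse frequencies, so a config that claims `rope_theta` of 1000000 describes different position encodings than the ones the model was actually trained with. A small illustration in plain PyTorch (same formula as the constructor above; the head dimension 128 is only an example value):

```
import torch

def inv_freq(dim: int, base: float) -> torch.Tensor:
    # Same expression as in LlamaRotaryEmbedding.__init__.
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

dim = 128  # example head_dim
print(torch.allclose(inv_freq(dim, 10_000.0), inv_freq(dim, 1_000_000.0)))  # False
```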

Files changed (1)
  1. config.json +1 -1
config.json CHANGED
```
@@ -18,7 +18,7 @@
   "pretraining_tp": 1,
   "rms_norm_eps": 1e-05,
   "rope_scaling": null,
-  "rope_theta": 1000000,
+  "rope_theta": 10000,
   "tie_word_embeddings": false,
   "torch_dtype": "float16",
   "transformers_version": "4.31.0",
```