Update modeling_hf_nomic_bert.py
Browse files
modeling_hf_nomic_bert.py
CHANGED
@@ -321,7 +321,8 @@ class NomicBertPreTrainedModel(PreTrainedModel):
Before:
321         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
322         num_labels = kwargs.pop("num_labels", None)
323         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
324 -
325         if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
326             config.n_positions = 2048
327         if num_labels:
@@ -554,6 +555,12 @@ class NomicBertRotaryEmbedding(nn.Module):
Before:
554         self.register_buffer("inv_freq", inv_freq, persistent=False)
555         self.interleaved = interleaved
556         self.scale_base = scale_base
557
558         self._seq_len_cached = 0
559         self._cos_cached = None
After (class NomicBertPreTrainedModel, lines 321-328):
321         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
322         num_labels = kwargs.pop("num_labels", None)
323         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
324 +       if rotary_scaling_factor:
325 +           config.rotary_scaling_factor = rotary_scaling_factor
326         if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
327             config.n_positions = 2048
328         if num_labels:
After (class NomicBertRotaryEmbedding, lines 555-566):
555         self.register_buffer("inv_freq", inv_freq, persistent=False)
556         self.interleaved = interleaved
557         self.scale_base = scale_base
558 +       scale = (
559 +           (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
560 +           if scale_base is not None
561 +           else None
562 +       )
563 +       self.register_buffer("scale", scale, persistent=False)
564
565         self._seq_len_cached = 0
566         self._cos_cached = None