jupyterjazz
commited on
Commit
•
1eb2361
1
Parent(s):
071760a
Update rotary.py
Browse files
rotary.py
CHANGED
@@ -494,14 +494,15 @@ class RotaryEmbedding(torch.nn.Module):
|
|
494 |
@base.setter
|
495 |
def base(self, new_base):
|
496 |
new_base = float(new_base)
|
497 |
-
if new_base > 0
|
498 |
-
self._base
|
499 |
-
|
500 |
-
self.
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
|
|
505 |
else:
|
506 |
raise ValueError("Rotary base value must be positive")
|
507 |
|
|
|
494 |
@base.setter
|
495 |
def base(self, new_base):
|
496 |
new_base = float(new_base)
|
497 |
+
if new_base > 0:
|
498 |
+
if self._base != new_base:
|
499 |
+
self._base = new_base
|
500 |
+
self._update_cos_sin_cache(
|
501 |
+
self._seq_len_cached,
|
502 |
+
device=self.inv_freq.device,
|
503 |
+
dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
|
504 |
+
rotary_base_changed=True,
|
505 |
+
)
|
506 |
else:
|
507 |
raise ValueError("Rotary base value must be positive")
|
508 |
|