jupyterjazz commited on
Commit
1eb2361
1 Parent(s): 071760a

Update rotary.py

Browse files
Files changed (1) hide show
  1. rotary.py +9 -8
rotary.py CHANGED
@@ -494,14 +494,15 @@ class RotaryEmbedding(torch.nn.Module):
494
  @base.setter
495
  def base(self, new_base):
496
  new_base = float(new_base)
497
- if new_base > 0 and new_base != self._base:
498
- self._base = new_base
499
- self._update_cos_sin_cache(
500
- self._seq_len_cached,
501
- device=self.inv_freq.device,
502
- dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
503
- rotary_base_changed=True,
504
- )
 
505
  else:
506
  raise ValueError("Rotary base value must be positive")
507
 
 
494
  @base.setter
495
  def base(self, new_base):
496
  new_base = float(new_base)
497
+ if new_base > 0:
498
+ if self._base != new_base:
499
+ self._base = new_base
500
+ self._update_cos_sin_cache(
501
+ self._seq_len_cached,
502
+ device=self.inv_freq.device,
503
+ dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
504
+ rotary_base_changed=True,
505
+ )
506
  else:
507
  raise ValueError("Rotary base value must be positive")
508