Update modeling_hf_nomic_bert.py
Browse files
modeling_hf_nomic_bert.py
CHANGED
@@ -619,7 +619,9 @@ class NomicBertRotaryEmbedding(nn.Module):
|
|
619 |
Apply rotary embedding *inplace* to qkv and / or kv.
|
620 |
"""
|
621 |
seqlen = qkv.shape[1]
|
622 |
-
if
|
|
|
|
|
623 |
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
624 |
elif isinstance(seqlen_offset, int):
|
625 |
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
|
|
619 |
Apply rotary embedding *inplace* to qkv and / or kv.
|
620 |
"""
|
621 |
seqlen = qkv.shape[1]
|
622 |
+
if seqlen > self._seq_len_cached:
|
623 |
+
self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
|
624 |
+
elif max_seqlen is not None:
|
625 |
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
626 |
elif isinstance(seqlen_offset, int):
|
627 |
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|