Remove seq_len arg in rotary_emb (#1443)
Browse files
* remove seq_len in llama rotary_emb
* chore: lint
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
@@ -284,12 +284,7 @@ def flashattn_forward_with_s2attn(
|
|
284 |
# [bsz, nh, q_len, hd]
|
285 |
# pylint: disable=duplicate-code
|
286 |
|
287 |
-
|
288 |
-
if past_key_value is not None:
|
289 |
-
kv_seq_len += past_key_value[0].shape[-2]
|
290 |
-
cos, sin = self.rotary_emb(
|
291 |
-
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
292 |
-
)
|
293 |
query_states, key_states = apply_rotary_pos_emb(
|
294 |
query_states, key_states, cos, sin, position_ids
|
295 |
)
|
@@ -435,13 +430,7 @@ def flashattn_forward(
|
|
435 |
# [bsz, q_len, nh, hd]
|
436 |
# [bsz, nh, q_len, hd]
|
437 |
|
438 |
-
|
439 |
-
if past_key_value is not None:
|
440 |
-
kv_seq_len += past_key_value[0].shape[-2]
|
441 |
-
|
442 |
-
cos, sin = self.rotary_emb(
|
443 |
-
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
444 |
-
)
|
445 |
query_states, key_states = apply_rotary_pos_emb(
|
446 |
query_states, key_states, cos, sin, position_ids
|
447 |
)
|
|
|
284 |
# [bsz, nh, q_len, hd]
|
285 |
# pylint: disable=duplicate-code
|
286 |
|
287 |
+
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
|
|
|
|
|
|
|
|
|
|
288 |
query_states, key_states = apply_rotary_pos_emb(
|
289 |
query_states, key_states, cos, sin, position_ids
|
290 |
)
|
|
|
430 |
# [bsz, q_len, nh, hd]
|
431 |
# [bsz, nh, q_len, hd]
|
432 |
|
433 |
+
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
434 |
query_states, key_states = apply_rotary_pos_emb(
|
435 |
query_states, key_states, cos, sin, position_ids
|
436 |
)
|
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
CHANGED
@@ -80,11 +80,7 @@ def xformers_forward(
|
|
80 |
# [bsz, q_len, nh, hd]
|
81 |
# [bsz, nh, q_len, hd]
|
82 |
|
83 |
-
|
84 |
-
if past_key_value is not None:
|
85 |
-
kv_seq_len += past_key_value[0].shape[-2]
|
86 |
-
|
87 |
-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
88 |
query_states, key_states = apply_rotary_pos_emb(
|
89 |
query_states, key_states, cos, sin, position_ids
|
90 |
)
|
|
|
80 |
# [bsz, q_len, nh, hd]
|
81 |
# [bsz, nh, q_len, hd]
|
82 |
|
83 |
+
cos, sin = self.rotary_emb(value_states)
|
|
|
|
|
|
|
|
|
84 |
query_states, key_states = apply_rotary_pos_emb(
|
85 |
query_states, key_states, cos, sin, position_ids
|
86 |
)
|