wenbopan winglian committed on
Commit
e07347b
1 Parent(s): bcdc9b1

Remove seq_len arg in rotary_emb (#1443)

Browse files

* remove seq_len in llama rotary_emb

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -284,12 +284,7 @@ def flashattn_forward_with_s2attn(
284
  # [bsz, nh, q_len, hd]
285
  # pylint: disable=duplicate-code
286
 
287
- kv_seq_len = key_states.shape[-2]
288
- if past_key_value is not None:
289
- kv_seq_len += past_key_value[0].shape[-2]
290
- cos, sin = self.rotary_emb(
291
- value_states, seq_len=kv_seq_len, position_ids=position_ids
292
- )
293
  query_states, key_states = apply_rotary_pos_emb(
294
  query_states, key_states, cos, sin, position_ids
295
  )
@@ -435,13 +430,7 @@ def flashattn_forward(
435
  # [bsz, q_len, nh, hd]
436
  # [bsz, nh, q_len, hd]
437
 
438
- kv_seq_len = key_states.shape[-2]
439
- if past_key_value is not None:
440
- kv_seq_len += past_key_value[0].shape[-2]
441
-
442
- cos, sin = self.rotary_emb(
443
- value_states, seq_len=kv_seq_len, position_ids=position_ids
444
- )
445
  query_states, key_states = apply_rotary_pos_emb(
446
  query_states, key_states, cos, sin, position_ids
447
  )
 
284
  # [bsz, nh, q_len, hd]
285
  # pylint: disable=duplicate-code
286
 
287
+ cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
 
 
 
 
 
288
  query_states, key_states = apply_rotary_pos_emb(
289
  query_states, key_states, cos, sin, position_ids
290
  )
 
430
  # [bsz, q_len, nh, hd]
431
  # [bsz, nh, q_len, hd]
432
 
433
+ cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
 
 
 
 
 
 
434
  query_states, key_states = apply_rotary_pos_emb(
435
  query_states, key_states, cos, sin, position_ids
436
  )
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -80,11 +80,7 @@ def xformers_forward(
80
  # [bsz, q_len, nh, hd]
81
  # [bsz, nh, q_len, hd]
82
 
83
- kv_seq_len = key_states.shape[-2]
84
- if past_key_value is not None:
85
- kv_seq_len += past_key_value[0].shape[-2]
86
-
87
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
88
  query_states, key_states = apply_rotary_pos_emb(
89
  query_states, key_states, cos, sin, position_ids
90
  )
 
80
  # [bsz, q_len, nh, hd]
81
  # [bsz, nh, q_len, hd]
82
 
83
+ cos, sin = self.rotary_emb(value_states)
 
 
 
 
84
  query_states, key_states = apply_rotary_pos_emb(
85
  query_states, key_states, cos, sin, position_ids
86
  )