SmerkyG committed
Commit d60e6fb
1 Parent(s): b0fcd94

Upload modeling_rwkv6qwen2.py with huggingface_hub

Files changed (1)
  1. modeling_rwkv6qwen2.py +4 -3
modeling_rwkv6qwen2.py CHANGED
@@ -204,6 +204,7 @@ class RWKV6State(Cache):
         # self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
 
 from fla.ops.gla.chunk import chunk_gla
+from fla.ops.gla.fused_recurrent import fused_recurrent_gla
 
 class RWKV6Attention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):
@@ -360,8 +361,8 @@ class RWKV6Attention(nn.Module):
         scale = query_states.shape[-1] ** -0.5
         output_final_state = not self.training and use_cache and past_key_value is not None
         #attn_output, output_kv_state = ChunkGLAFunction.apply(query_states, key_states, value_states, decay_states_log.float(), scale, input_kv_state, output_final_state)
-        attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
-        #attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, input_kv_state, output_final_state)
+        #attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
+        attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)
 
         if output_final_state:
             past_key_value.update(output_kv_state, output_shift_state, q_len, self.layer_idx)
@@ -1207,4 +1208,4 @@ class RWKV6Qwen2ForQuestionAnswering(RWKV6Qwen2PreTrainedModel):
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
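The net change swaps the chunked GLA kernel for the fused recurrent GLA kernel from the fla (flash-linear-attention) package. Below is a minimal sketch of both calls, mirroring the positional argument pattern visible in the diff; the tensor shapes, the head-first layout, and the (q, k, v, gk, gv, scale, initial_state, output_final_state) ordering for fused_recurrent_gla are assumptions taken from this commit rather than a documented API, and may differ across fla versions.

# Illustrative sketch only: argument order, layout, and shapes are assumed from the
# diff above, not from fla documentation, and may vary between fla releases.
import torch
from fla.ops.gla.chunk import chunk_gla
from fla.ops.gla.fused_recurrent import fused_recurrent_gla

B, H, T, D = 1, 8, 16, 64                       # batch, heads, sequence, head dim (assumed layout)
q = torch.randn(B, H, T, D, device='cuda')
k = torch.randn(B, H, T, D, device='cuda')
v = torch.randn(B, H, T, D, device='cuda')
# log-space decay (gate) values: logsigmoid keeps them negative, i.e. log of a value in (0, 1)
g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, D, device='cuda'))
scale = D ** -0.5
state = None                                    # no incoming recurrent state

# Previous code path: chunked kernel, generally used for parallel prefill/training.
out_chunk, kv_state = chunk_gla(q, k, v, g, scale, state, True)

# New code path: fused recurrent kernel, generally preferred for step-by-step decoding.
# The extra None is presumably the value-gate slot (gv), left unused as in the commit.
out_rec, kv_state = fused_recurrent_gla(q, k, v, g, None, scale, state, True)

Both calls return the attention output together with the final recurrent state, which the surrounding RWKV6Attention code stores via past_key_value.update when output_final_state is set.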