Upload modeling_rwkv6qwen2.py with huggingface_hub
modeling_rwkv6qwen2.py CHANGED (+4 -3)
@@ -204,6 +204,7 @@ class RWKV6State(Cache):
         # self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
 
 from fla.ops.gla.chunk import chunk_gla
+from fla.ops.gla.fused_recurrent import fused_recurrent_gla
 
 class RWKV6Attention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):
@@ -360,8 +361,8 @@ class RWKV6Attention(nn.Module):
         scale = query_states.shape[-1] ** -0.5
         output_final_state = not self.training and use_cache and past_key_value is not None
         #attn_output, output_kv_state = ChunkGLAFunction.apply(query_states, key_states, value_states, decay_states_log.float(), scale, input_kv_state, output_final_state)
-        attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
-
+        #attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
+        attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)
 
         if output_final_state:
             past_key_value.update(output_kv_state, output_shift_state, q_len, self.layer_idx)
@@ -1207,4 +1208,4 @@ class RWKV6Qwen2ForQuestionAnswering(RWKV6Qwen2PreTrainedModel):
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
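For context, a minimal sketch of how the two fla GLA kernels are called with the positional arguments used in the diff above. The tensor shapes, dtypes, and head-first layout below are illustrative assumptions, not taken from the model code, and both kernels are Triton-based, so a CUDA device is assumed.

import torch
import torch.nn.functional as F
from fla.ops.gla.chunk import chunk_gla
from fla.ops.gla.fused_recurrent import fused_recurrent_gla

# Assumed (batch, heads, seq_len, head_dim) layout for illustration only.
B, H, T, D = 1, 4, 16, 64
device, dtype = "cuda", torch.bfloat16
q = torch.randn(B, H, T, D, device=device, dtype=dtype)
k = torch.randn(B, H, T, D, device=device, dtype=dtype)
v = torch.randn(B, H, T, D, device=device, dtype=dtype)
g = F.logsigmoid(torch.randn(B, H, T, D, device=device)).to(dtype)  # log decay gate in (-inf, 0)
scale = D ** -0.5
initial_state = None  # or a recurrent kv state carried over from a cache

# Chunk-parallel kernel (the call the commit comments out); returns (output, final_kv_state).
out_chunk, kv_state = chunk_gla(q, k, v, g, scale, initial_state, True)

# Fused recurrent kernel (the call the commit switches to); the extra None is the value gate gv.
out_rec, kv_state = fused_recurrent_gla(q, k, v, g, None, scale, initial_state, True)

In the fla library, the chunked kernel parallelizes over sequence chunks while the fused recurrent kernel steps through time one token at a time, which generally makes the latter the lighter-weight choice for short or single-token sequences.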