Fix typo in modeling_grok
modeling_grok.py  (+1 -1)
@@ -273,7 +273,7 @@ class GrokAttention(nn.Module):
         value_states = repeat_kv(value_states, self.num_key_value_groups)

         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.attn_output_multiplier
-
+        attn_weights = 30 * torch.tanh(attn_weights / 30)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
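For reference, a minimal sketch of the attention-logit soft-capping this hunk restores, assuming the 4-D (bsz, num_heads, seq_len, head_dim) shapes used by the surrounding GrokAttention code; the helper name and the max_attn_value parameter below are illustrative, not part of the repository:

    import torch

    def capped_attention_logits(query_states, key_states, attn_output_multiplier, max_attn_value=30.0):
        # Scaled dot-product logits, as in the unchanged line above the hunk.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * attn_output_multiplier
        # Soft-cap each logit to (-max_attn_value, max_attn_value) with tanh,
        # matching the added line: attn_weights = 30 * torch.tanh(attn_weights / 30).
        return max_attn_value * torch.tanh(attn_weights / max_attn_value)

The tanh cap bounds every logit before the softmax, so a single large dot product cannot saturate the attention distribution.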