update modeling_qwen.py
modeling_qwen.py  CHANGED  (+4 -6)
@@ -520,9 +520,7 @@ class QWenAttention(nn.Module):
 
         if not self.use_cache_quantization and SUPPORT_TORCH2:
             if attention_mask is not None:
-                attention_mask = attention_mask.expand(
-                    -1, -1, causal_mask.size(2), -1
-                )
+                attention_mask = attention_mask.expand(-1, -1, key_size, -1)
                 if causal_mask is not None:
                     attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
             else:
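Context for the hunk above: the removed code called causal_mask.size(2) before causal_mask had been checked against None on the line below, while the one-line replacement sizes the expand from key_size instead. The following is a minimal sketch of the expand + masked_fill pattern, assuming illustrative shapes (a prefill step where the query length equals key_size); the dimensions are made up, not taken from the QWenAttention call site.

import torch

# Illustrative shapes (assumptions, not values from the model).
batch, key_size = 2, 5

# Padding mask, broadcastable over heads and query positions.
attention_mask = torch.zeros(batch, 1, 1, key_size)

# Boolean causal mask: True where attention is allowed.
causal_mask = torch.tril(
    torch.ones(key_size, key_size, dtype=torch.bool)
)[None, None]  # (1, 1, key_size, key_size)

# expand(-1, -1, key_size, -1) materializes the singleton query
# dimension as a view (no copy); -1 leaves a dimension unchanged.
attention_mask = attention_mask.expand(-1, -1, key_size, -1)
print(attention_mask.shape)  # torch.Size([2, 1, 5, 5])

# masked_fill writes the dtype minimum wherever ~causal_mask is True
# (future positions), so softmax gives those positions ~zero weight.
attention_mask = attention_mask.masked_fill(
    ~causal_mask, torch.finfo(torch.float32).min
)

Because expand returns a view, no extra memory is allocated until masked_fill produces the final tensor.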
@@ -1330,14 +1328,14 @@ def apply_rotary_pos_emb(t, freqs):
       t (tensor(batch_size, seq_len, n_head, head_dim)):
         the input embedding/hidden states
       freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]):
-        the cached cos/sin position embeddings
+        the cached cos/sin position embeddings
     """
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()
     if apply_rotary_emb_func is not None and t.is_cuda:
-        # apply_rotary_emb in flash_attn requires cos/sin to be of
-        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
+        # apply_rotary_emb in flash_attn requires cos/sin to be of
+        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
         # to the first rotary_dim of the input
         cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
         sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]