yangapku commited on
Commit
5d52159
·
1 Parent(s): b980709

update modeling_qwen.py

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +4 -3
modeling_qwen.py CHANGED
@@ -175,6 +175,7 @@ class FlashSelfAttention(torch.nn.Module):
175
  assert all((i.is_cuda for i in (q, k, v)))
176
  batch_size, seqlen_q = q.shape[0], q.shape[1]
177
  seqlen_k = k.shape[1]
 
178
 
179
  q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
180
  cu_seqlens_q = torch.arange(
@@ -187,11 +188,11 @@ class FlashSelfAttention(torch.nn.Module):
187
 
188
  if attention_mask is not None:
189
  k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
190
- v = v[indices_k]
191
- if self.training or q.size(0) == k.size(0):
192
  q = q[indices_k]
193
  cu_seqlens_q = cu_seqlens_k
194
  seqlen_q = seqlen_k
 
195
  else:
196
  cu_seqlens_k = torch.arange(
197
  0,
@@ -222,7 +223,7 @@ class FlashSelfAttention(torch.nn.Module):
222
  causal=is_causal,
223
  )
224
  if attention_mask is not None and seqlen_q == seqlen_k:
225
- output = self.pad_input(output, indices_k, batch_size, seqlen_q)
226
  else:
227
  new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
228
  output = output.view(new_shape)
 
175
  assert all((i.is_cuda for i in (q, k, v)))
176
  batch_size, seqlen_q = q.shape[0], q.shape[1]
177
  seqlen_k = k.shape[1]
178
+ seqlen_out = seqlen_q
179
 
180
  q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
181
  cu_seqlens_q = torch.arange(
 
188
 
189
  if attention_mask is not None:
190
  k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
191
+ if q.size(0) == v.size(0):
 
192
  q = q[indices_k]
193
  cu_seqlens_q = cu_seqlens_k
194
  seqlen_q = seqlen_k
195
+ v = v[indices_k]
196
  else:
197
  cu_seqlens_k = torch.arange(
198
  0,
 
223
  causal=is_causal,
224
  )
225
  if attention_mask is not None and seqlen_q == seqlen_k:
226
+ output = self.pad_input(output, indices_k, batch_size, seqlen_out)
227
  else:
228
  new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
229
  output = output.view(new_shape)