calculating commited on
Commit
896572e
1 Parent(s): 168d5e3

committing...

Browse files
Files changed (1) hide show
  1. transformer.py +6 -7
transformer.py CHANGED
@@ -191,13 +191,12 @@ class GQA(nn.Module):
191
  def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
192
  k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
193
  v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
194
- with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) if k.device.type == 'cuda' else nullcontext():
195
- x = F.scaled_dot_product_attention(
196
- q.transpose(1, 2),
197
- k.transpose(1, 2),
198
- v.transpose(1, 2),
199
- is_causal=False if (q.size(1) != k.size(1)) else self.causal,
200
- )
201
  x = x.transpose(1, 2).contiguous()
202
  return x
203
 
 
191
  def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
192
  k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
193
  v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
194
+ x = F.scaled_dot_product_attention(
195
+ q.transpose(1, 2),
196
+ k.transpose(1, 2),
197
+ v.transpose(1, 2),
198
+ is_causal=False if (q.size(1) != k.size(1)) else self.causal,
199
+ )
 
200
  x = x.transpose(1, 2).contiguous()
201
  return x
202