Spaces:
Runtime error
Runtime error
calculating
commited on
Commit
•
896572e
1
Parent(s):
168d5e3
committing...
Browse files- transformer.py +6 -7
transformer.py
CHANGED
@@ -191,13 +191,12 @@ class GQA(nn.Module):
|
|
191 |
def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
192 |
k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
|
193 |
v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
)
|
201 |
x = x.transpose(1, 2).contiguous()
|
202 |
return x
|
203 |
|
|
|
191 |
def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
192 |
k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
|
193 |
v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
|
194 |
+
x = F.scaled_dot_product_attention(
|
195 |
+
q.transpose(1, 2),
|
196 |
+
k.transpose(1, 2),
|
197 |
+
v.transpose(1, 2),
|
198 |
+
is_causal=False if (q.size(1) != k.size(1)) else self.causal,
|
199 |
+
)
|
|
|
200 |
x = x.transpose(1, 2).contiguous()
|
201 |
return x
|
202 |
|