import torch


def attn_ref(q, k, v, b, sm_scale, dropout_p=0.0, causal=False, upcast=False):
    """Eager reference implementation of scaled-dot-product attention
    with an optional additive bias.

    q: (batch, heads, q_len, head_dim)
    k, v: (batch, heads, k_len, head_dim)
    b: optional additive bias, broadcastable to (batch, heads, q_len, k_len)
    """
    # Optionally compute in float32 for a higher-precision reference.
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        if b is not None:
            b = b.float()

    # Materialize the bias to the full (batch, heads, q_len, k_len) shape when
    # its leading dimensions were broadcast (e.g. size-1 batch or head dims).
    if b is not None:
        if (b.shape[0] != q.shape[0]) or (b.shape[1] != q.shape[1]):
            b = b.expand(q.shape[0], q.shape[1], q.shape[2], k.shape[2])

    # Row (query) and column (key) indices used to build the causal mask.
    ms = torch.arange(q.shape[2], device=q.device).unsqueeze(-1)
    ns = torch.arange(k.shape[2], device=q.device)

    # Scaled attention scores, plus the additive bias if one was given.
    p = torch.matmul(q, k.transpose(2, 3))
    p *= sm_scale
    if b is not None:
        p += b

    # Bottom-right-aligned causal mask: query i may attend to key j only when
    # i + (k_len - q_len) >= j, so the last query row sees every key.
    if causal:
        p = torch.where(ms + k.shape[2] - q.shape[2] >= ns, p, float("-inf"))

    # Softmax in float32 for numerical stability, then cast back to the input dtype.
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    if dropout_p > 0.0:
        p = torch.dropout(p, dropout_p, train=True)

    ref_out = torch.matmul(p, v)
    return ref_out
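

# A hedged usage sketch, not from the original source: a quick smoke test that
# compares attn_ref against PyTorch's built-in scaled_dot_product_attention.
# Assumes PyTorch >= 2.1 (for the `scale` keyword); the shapes below are
# illustrative. With equal query/key lengths, both apply the same
# lower-triangular causal mask.
if __name__ == "__main__":
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch, heads, seq_len, head_dim = 2, 4, 128, 64
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)
    v = torch.randn(batch, heads, seq_len, head_dim)
    sm_scale = head_dim ** -0.5

    out_ref = attn_ref(q, k, v, b=None, sm_scale=sm_scale, causal=True)
    out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=sm_scale)
    # The two should agree to float32 tolerance when dropout is disabled.
    print("max abs diff:", (out_ref - out_sdpa).abs().max().item())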