import torch

# Fused attention kernel from the Triton flash-attention implementation
from triton_flash_atn import _attention

# Problem sizes: small enough to run quickly, large enough to exercise the kernel
batch_size = 2
num_heads = 4
seq_len = 128
head_dim = 64

# Random half-precision inputs on the GPU, shaped (batch, heads, seq_len, head_dim)
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')

# No causal masking; scale attention scores by 1/sqrt(head_dim)
causal = False
sm_scale = 1.0 / (head_dim ** 0.5)

# _attention is a torch.autograd.Function, so it is invoked through .apply
attention = _attention.apply
output = attention(q, k, v, causal, sm_scale)

print(output)
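
# Optional sanity check, a sketch not in the original snippet: recompute
# softmax(Q K^T * sm_scale) V with plain PyTorch in float32 and compare.
# (Valid for causal=False; a causal reference would also need a mask.)
# The loose tolerance allows for float16 accumulation differences in the kernel.
ref = torch.softmax(
    (q.float() @ k.float().transpose(-2, -1)) * sm_scale, dim=-1
) @ v.float()
torch.testing.assert_close(output.float(), ref, atol=2e-2, rtol=0)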
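
# Gradient sketch, assuming _attention implements a backward pass: clone the
# inputs with requires_grad=True, rerun the fused attention, and backpropagate
# a scalar loss to obtain dQ, dK, dV.
q_g, k_g, v_g = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
attention(q_g, k_g, v_g, causal, sm_scale).sum().backward()
print(q_g.grad.shape)  # (batch_size, num_heads, seq_len, head_dim)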