File size: 734 Bytes
5c0cb68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from triton_flash_atn import _attention

# Problem sizes for the demo.
batch_size = 2
num_heads = 4
seq_len = 128
head_dim = 64

# Q, K, V share a single (batch, heads, seq, dim) layout; half precision
# on the GPU, matching what the Triton kernel is fed below.
shape = (batch_size, num_heads, seq_len, head_dim)
q, k, v = (
    torch.randn(shape, dtype=torch.float16, device='cuda')
    for _ in range(3)
)

# Non-causal attention with the conventional 1/sqrt(d) softmax scaling.
# head_dim ** -0.5 == 1.0 / (head_dim ** 0.5) exactly (0.125 for d=64).
causal = False
sm_scale = head_dim ** -0.5

# NOTE(review): the .apply call suggests _attention is a custom
# torch.autograd.Function — confirm in triton_flash_atn.
attention = _attention.apply
output = attention(q, k, v, causal, sm_scale)

print(output)