import einops
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import SDPBackend

from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
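
# Overview of the entry points defined below:
#   benchmark_attn               - raw F.scaled_dot_product_attention across SDP backends
#   benchmark_transformer_blocks - SpatialTransformer with native SDP vs. xformers attention
#   test01                       - 1x1 Conv2d vs. Linear equivalence check
#   test02                       - BasicTransformerBlock timing: vanilla softmax vs. flash-cosine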


def benchmark_attn():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Time a callable with torch.utils.benchmark and return the mean runtime in microseconds.
    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6
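
    # Input sizes for the raw SDPA benchmark; q/k/v tensors are
    # (batch_size, num_heads, max_sequence_len, embed_dimension).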
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32

    dtype = torch.float16

    query = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    key = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    value = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )

    print("q/k/v shape:", query.shape, key.shape, value.shape)

    from torch.backends.cuda import SDPBackend, sdp_kernel
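
    # Keyword arguments for sdp_kernel; each entry enables exactly one scaled-dot-product backend.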
    backend_map = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
    }

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
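
    # For each backend: report the mean SDPA runtime, then profile 25 calls and print the
    # top 10 ops by CUDA time.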
    print(
        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
    )
    with profile(
        activities=activities, record_shapes=False, profile_memory=True
    ) as prof:
        with record_function("Default detailed stats"):
            for _ in range(25):
                o = F.scaled_dot_product_attention(query, key, value)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        print(
            f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
        )
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("Math implementation stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        try:
            print(
                f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
            with profile(
                activities=activities, record_shapes=False, profile_memory=True
            ) as prof:
                with record_function("FlashAttention stats"):
                    for _ in range(25):
                        o = F.scaled_dot_product_attention(query, key, value)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        except RuntimeError:
            print("FlashAttention is not supported. See warnings for reasons.")

    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
        try:
            print(
                f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
            with profile(
                activities=activities, record_shapes=False, profile_memory=True
            ) as prof:
                with record_function("EfficientAttention stats"):
                    for _ in range(25):
                        o = F.scaled_dot_product_attention(query, key, value)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        except RuntimeError:
            print("EfficientAttention is not supported. See warnings for reasons.")


def run_model(model, x, context):
    return model(x, context)


def benchmark_transformer_blocks():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    import torch.utils.benchmark as benchmark

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    checkpoint = True
    compile = False

    batch_size = 32
    h, w = 64, 64
    context_len = 77
    embed_dimension = 1024
    context_dim = 1024
    d_head = 64

    transformer_depth = 4

    n_heads = embed_dimension // d_head

    dtype = torch.float16
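
    # Two SpatialTransformer variants with identical shapes: one using PyTorch's native
    # scaled-dot-product attention (flash backend), one using xformers memory-efficient attention.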
    model_native = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        use_checkpoint=checkpoint,
        attn_type="softmax",
        depth=transformer_depth,
        sdp_backend=SDPBackend.FLASH_ATTENTION,
    ).to(device)
    model_efficient_attn = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        depth=transformer_depth,
        use_checkpoint=checkpoint,
        attn_type="softmax-xformers",
    ).to(device)
    if not checkpoint and compile:
        print("compiling models")
        model_native = torch.compile(model_native)
        model_efficient_attn = torch.compile(model_efficient_attn)

    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    with torch.autocast("cuda"):
        print(
            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
        )
        print(
            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
        )

        print(75 * "+")
        print("NATIVE")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("NativeAttention stats"):
                for _ in range(25):
                    model_native(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")

        print(75 * "+")
        print("Xformers")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("xformers stats"):
                for _ in range(25):
                    model_efficient_attn(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")


def test01():
    from sgm.util import count_params
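
    # A 1x1 Conv2d and a Linear layer acting on the channel dimension compute the same map;
    # copy the conv weights into the linear layer and compare outputs on the same input.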
    conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
    print(count_params(conv))
    linear = torch.nn.Linear(3, 32).cuda()
    print(count_params(linear))

    print(conv.weight.shape)

    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias = torch.nn.Parameter(conv.bias)

    print(linear.weight.shape)

    x = torch.randn(11, 3, 64, 64).cuda()

    xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
    print(xr.shape)
    out_linear = linear(xr)
    print(out_linear.mean(), out_linear.shape)

    out_conv = conv(x)
    print(out_conv.mean(), out_conv.shape)
    print("done with test01.\n")


def test02():
    import time

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    print("testing cosine flash attention...")
    DIM = 1024
    SEQLEN = 4096
    BS = 16
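
    # Time one forward pass of a BasicTransformerBlock with vanilla softmax attention, then with
    # flash-cosine attention, on the same input shape. Note: plain time.time() without
    # torch.cuda.synchronize(), so treat the numbers as rough.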
    print(" softmax (vanilla) first...")
    model = BasicTransformerBlock(
        dim=DIM,
        n_heads=16,
        d_head=64,
        dropout=0.0,
        context_dim=None,
        attn_mode="softmax",
    ).cuda()
    try:
        x = torch.randn(BS, SEQLEN, DIM).cuda()
        tic = time.time()
        y = model(x)
        toc = time.time()
        print(y.shape, toc - tic)
    except RuntimeError as e:
        print(str(e))

    print("\n now flash-cosine...")
    model = BasicTransformerBlock(
        dim=DIM,
        n_heads=16,
        d_head=64,
        dropout=0.0,
        context_dim=None,
        attn_mode="flash-cosine",
    ).cuda()
    x = torch.randn(BS, SEQLEN, DIM).cuda()
    tic = time.time()
    y = model(x)
    toc = time.time()
    print(y.shape, toc - tic)
    print("done with test02.\n")


if __name__ == "__main__":
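    # The other entry points defined above can be enabled here as needed:
    # benchmark_attn()
    # test01()
    # test02()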
    benchmark_transformer_blocks()

    print("done.")