from contextlib import contextmanager

import torch

try:
    from flash_attn import flash_attn_varlen_qkvpacked_func as flash_varlen_qkvpacked_attn
except ImportError:
    flash_varlen_qkvpacked_attn = None

try:
    from sageattention import sageattn as sage_attn
except ImportError:
    sage_attn = None

try:
    from comfy.ldm.modules.attention import optimized_attention as comfy_attn
except ImportError:
    comfy_attn = None
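
# Each optional backend above resolves to None when its package is not installed, so
# callers can feature-detect before dispatching to it. A minimal sketch (the q/k/v
# tensors and the bare positional call are illustrative assumptions, not a statement
# of each library's exact API):
#
#     if sage_attn is not None:
#         out = sage_attn(q, k, v)  # check sageattention's docs for the full signature
#     elif flash_varlen_qkvpacked_attn is not None:
#         ...  # expects packed qkv plus cu_seqlens / max_seqlen arguments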

from torch.nn.attention import SDPBackend, sdpa_kernel

# Pick which scaled_dot_product_attention backends to allow, based on the GPU's
# compute capability. Only probe the device when CUDA is actually available.
backends = []
if torch.cuda.is_available():
    major = torch.cuda.get_device_properties(0).major
    if major < 7:
        # Older GPUs (compute capability < 7): also allow the math fallback.
        backends.append(SDPBackend.MATH)
    if major >= 9:
        # Hopper (compute capability 9.x) and newer: use cuDNN attention.
        backends.append(SDPBackend.CUDNN_ATTENTION)
    else:
        backends.append(SDPBackend.EFFICIENT_ATTENTION)


@contextmanager
def sdpa_attn_ctx():
    """Run the wrapped code with scaled_dot_product_attention restricted to ``backends``."""
    with sdpa_kernel(backends):
        yield
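
# Example usage of sdpa_attn_ctx (a minimal sketch; the tensor shapes, dtype, and
# device below are illustrative assumptions, not requirements of this module):
#
#     q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
#     with sdpa_attn_ctx():
#         out = torch.nn.functional.scaled_dot_product_attention(q, k, v)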