# mm/src/genmo/lib/attn_imports.py
"""Optional attention kernel imports and SDPA backend selection for genmo."""

from contextlib import contextmanager

import torch
# flash-attn's fused varlen QKV-packed kernel, if the package is installed.
try:
    from flash_attn import flash_attn_varlen_qkvpacked_func as flash_varlen_qkvpacked_attn
except ImportError:
    flash_varlen_qkvpacked_attn = None

# SageAttention's quantized attention kernel, if the package is installed.
try:
    from sageattention import sageattn as sage_attn
except ImportError:
    sage_attn = None

# ComfyUI's optimized attention, available when running inside ComfyUI.
try:
    from comfy.ldm.modules.attention import comfy_optimized_attention as comfy_attn
except ImportError:
    comfy_attn = None
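
# Illustrative helper (an assumption, not part of the original module): the
# None fallbacks above let callers feature-detect at runtime, for example by
# taking the first optional kernel that imported successfully. The priority
# order below is hypothetical.
def _first_available_attn():
    for candidate in (flash_varlen_qkvpacked_attn, sage_attn, comfy_attn):
        if candidate is not None:
            return candidate
    return None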

from torch.nn.attention import SDPBackend, sdpa_kernel

# Choose SDPA backends from the compute capability of the first CUDA device.
major = torch.cuda.get_device_properties(0).major
backends = []
if major < 7:
    # Pre-Volta GPUs: also allow the math backend as a fallback.
    backends.append(SDPBackend.MATH)
if major >= 9:
    # Hopper and newer: use the cuDNN attention backend.
    backends.append(SDPBackend.CUDNN_ATTENTION)
else:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)

@contextmanager
def sdpa_attn_ctx():
    """Context manager that restricts SDPA to the backends selected above."""
    with sdpa_kernel(backends):
        yield
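
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): run a single
    # scaled_dot_product_attention call under the selected backends.
    # Shapes are arbitrary; float16 is supported by all three backends
    # chosen above.
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    with sdpa_attn_ctx():
        out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 128, 64])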