|
|
|
import warnings |
|
import os |
|
from pathlib import Path |
|
|
|
from packaging.version import parse, Version |
|
from setuptools import setup, find_packages |
|
import subprocess |
|
|
|
|
|
import torch |
|
from torch.utils.cpp_extension import ( |
|
BuildExtension, |
|
CppExtension, |
|
CUDAExtension, |
|
CUDA_HOME, |
|
) |
|
|
|
# Distribution metadata handed to setup().
PACKAGE_NAME = "blackmamba"
VERSION = "0.0.1"

# Directory that contains this setup script. Files shipped next to it are
# located through this path instead of the current working directory.
this_dir = os.path.dirname(os.path.abspath(__file__))

# Read the README relative to this file so the build does not depend on the
# directory it happens to be launched from.
with open(os.path.join(this_dir, "README.md"), "r", encoding="utf-8") as fh:
    long_description = fh.read()

# Build switches taken from the environment. A switch is on only when its
# variable is set to the exact string "TRUE"; anything else (or unset) is off.
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
# When on, no CUDA extension is added to ext_modules.
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# When on, torch._C._GLIBCXX_USE_CXX11_ABI is set to True before the
# extension is declared.
FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
|
|
|
|
|
def get_cuda_bare_metal_version(cuda_dir):
    """Ask the nvcc found under *cuda_dir* which CUDA release it belongs to.

    Runs ``<cuda_dir>/bin/nvcc -V``, splits the output on whitespace and takes
    the token that follows the word ``release``, keeping only the part before
    the first comma.

    Args:
        cuda_dir: Root directory of the CUDA installation, as a string.

    Returns:
        A ``(raw_output, version)`` tuple: the unmodified text printed by nvcc
        and the release parsed with ``packaging.version.parse``.

    Raises:
        subprocess.CalledProcessError: If nvcc exits with a non-zero status.
        ValueError: If the output contains no ``release`` token.
    """
    nvcc_path = cuda_dir + "/bin/nvcc"
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)

    tokens = raw_output.split()
    # The version is the token right after "release"; strip whatever follows
    # the first comma in that token.
    release_token = tokens[tokens.index("release") + 1]
    release_text = release_token.partition(",")[0]

    return raw_output, parse(release_text)
|
|
|
|
|
def check_if_cuda_home_none(global_option: str) -> None:
    """Emit a warning when ``CUDA_HOME`` is ``None``.

    Nothing happens when ``CUDA_HOME`` is set. The function only warns; it
    never raises and never aborts the build.

    Args:
        global_option: Name of the requested feature or package, used as the
            first word of the warning text.
    """
    if CUDA_HOME is None:
        message = (
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
        warnings.warn(message)
|
|
|
|
|
def append_nvcc_threads(nvcc_extra_args, nvcc_threads=4):
    """Return the nvcc arguments extended with a ``--threads`` option.

    Args:
        nvcc_extra_args: Sequence of nvcc command-line arguments. It is left
            unmodified.
        nvcc_threads: Value passed to ``--threads``. Defaults to 4.

    Returns:
        A new list holding the given arguments followed by ``"--threads"`` and
        the thread count rendered as a string.
    """
    return [*nvcc_extra_args, "--threads", str(nvcc_threads)]
|
|
|
|
|
# Extensions handed to setup(). Stays empty when the CUDA build is skipped.
ext_modules = []

if not SKIP_CUDA_BUILD:
    print(f"\n\ntorch.__version__ = {torch.__version__}\n\n")
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # This only warns; the extension is still declared when nvcc is missing.
    check_if_cuda_home_none(PACKAGE_NAME)

    # Stays None when no CUDA toolkit was located, so the version check on the
    # sm_90 flag below can never touch an unbound name.
    bare_metal_version = None
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

    # Device-code targets passed to nvcc as "-gencode <spec>" pairs.
    cc_flag = [
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-gencode",
        "arch=compute_80,code=sm_80",
    ]
    # sm_90 is requested only when nvcc reports release 11.8 or newer.
    if bare_metal_version is not None and bare_metal_version >= Version("11.8"):
        cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

    # Must be set before the extension is declared below.
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    ext_modules.append(
        CUDAExtension(
            name="selective_scan_cuda",
            sources=[
                "csrc/selective_scan/selective_scan.cpp",
                "csrc/selective_scan/selective_scan_fwd_fp32.cu",
                "csrc/selective_scan/selective_scan_fwd_fp16.cu",
                "csrc/selective_scan/selective_scan_fwd_bf16.cu",
                "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
                "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
                "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
                "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
                "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
                "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        "--ptxas-options=-v",
                        "-lineinfo",
                    ]
                    + cc_flag
                ),
            },
            include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
        )
    )
|
|
|
|
|
# Declare the distribution: metadata, the pure-Python package, the optional
# CUDA extension collected in ext_modules, and the runtime requirements.
setup(
    name=PACKAGE_NAME,
    version=VERSION,
    description="Blackmamba state-space + MoE model",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # ``exclude`` is an argument of find_packages(), not a setup() keyword, so
    # it is passed there together with ``include``.
    packages=find_packages(
        include=["ops"],
        exclude=(
            "csrc",
            "blackmamba.egg-info",
        ),
    ),
    ext_modules=ext_modules,
    # torch's BuildExtension drives compilation of the C++/CUDA sources.
    cmdclass={"build_ext": BuildExtension},
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
        "einops",
        "triton",
        "transformers",
        "causal_conv1d>=1.1.0",
    ],
)