File size: 3,028 Bytes
8b617cc fac2d98 161bcb6 37293dc 77fca25 3355706 34c0a86 3355706 d69ba2b 1648279 d69ba2b 3355706 d69ba2b 3355706 8d288a2 161bcb6 fac2d98 161bcb6 fac2d98 161bcb6 8d288a2 3355706 77fca25 2bc1a5b 1427d5b 6c5fbe6 2bc1a5b 77fca25 3355706 77fca25 cf66547 58b0d4b cf66547 732851f c25ba79 5be8b55 8a49309 2bc1a5b 40a6362 9be92d1 5894f0e 1648279 77fca25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup
def parse_requirements():
    """Build the install requirements for axolotl from ``./requirements.txt``.

    Each stripped line of the requirements file is classified as follows:

    * ``--extra-index-url`` lines (``--extra-index-url URL`` or
      ``--extra-index-url=URL``) contribute their URL to the dependency links.
    * Blank lines and ``#`` comment lines are ignored.
    * Lines mentioning flash-attn, flash-attention, deepspeed, mamba-ssm or
      lion-pytorch are skipped here; they are not part of the base install.
    * Every other line is kept verbatim as an install requirement.

    The list is then adjusted for the host:

    * On Darwin the ``xformers==0.0.22`` pin is dropped.
    * Elsewhere, if torch is installed, it is pinned to the installed version,
      and for torch >= 2.1 the ``xformers==0.0.22`` pin is swapped for
      ``xformers>=0.0.23``.
    * If torch is not installed, the requirements are left as parsed.

    A missing ``xformers==0.0.22`` pin is tolerated: nothing is removed or
    substituted in that case.

    Returns:
        tuple[list[str], list[str]]: ``(install_requires, dependency_links)``.

    Raises:
        OSError: if ``./requirements.txt`` cannot be opened.
        ValueError: if the installed torch version does not start with
            ``<major>.<minor>``.
    """
    extras_markers = (
        "flash-attn",
        "flash-attention",
        "deepspeed",
        "mamba-ssm",
        "lion-pytorch",
    )
    index_option = "--extra-index-url"
    legacy_xformers = "xformers==0.0.22"

    _install_requires = []
    _dependency_links = []

    with open("./requirements.txt", encoding="utf-8") as requirements_file:
        for raw_line in requirements_file:
            line = raw_line.strip()
            if line.startswith(index_option):
                # Handle custom index URLs; the option and its value may be
                # separated by whitespace or by "=".
                remainder = line[len(index_option):]
                if remainder.startswith("="):
                    remainder = remainder[1:]
                tokens = remainder.split()
                if tokens:
                    _dependency_links.append(tokens[0])
            elif (
                line
                and line[0] != "#"
                and not any(marker in line for marker in extras_markers)
            ):
                # Handle standard packages
                _install_requires.append(line)

    try:
        if "Darwin" in platform.system():
            # Drop the xformers pin on macOS (only if it is actually listed).
            if legacy_xformers in _install_requires:
                _install_requires.remove(legacy_xformers)
        else:
            # Raises PackageNotFoundError when torch is not installed.
            torch_version = version("torch")
            _install_requires.append(f"torch=={torch_version}")

            version_match = re.match(r"^(\d+)\.(\d+)", torch_version)
            if not version_match:
                raise ValueError("Invalid version format")
            major, minor = (int(part) for part in version_match.groups())

            if (major, minor) >= (2, 1) and legacy_xformers in _install_requires:
                # Swap the legacy pin for the newer xformers range.
                _install_requires.remove(legacy_xformers)
                _install_requires.append("xformers>=0.0.23")
    except PackageNotFoundError:
        # torch is not installed: keep the requirements exactly as parsed.
        pass

    return _install_requires, _dependency_links
# Base requirements and extra index URLs, derived from requirements.txt and
# adjusted for the local platform / installed torch version.
install_requires, dependency_links = parse_requirements()

setup(
    name="axolotl",
    version="0.4.0",
    description="LLM Trainer",
    long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
    # src layout: the root package namespace maps to ./src, so package
    # discovery has to start from that same directory.
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    install_requires=install_requires,
    # NOTE(review): recent pip releases ignore dependency_links; the extra
    # index URLs may need to be passed to pip explicitly -- confirm.
    dependency_links=dependency_links,
    # Optional feature sets, installable as e.g. `axolotl[deepspeed]`.
    extras_require={
        "flash-attn": [
            "flash-attn==2.5.5",
        ],
        "fused-dense-lib": [
            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
        ],
        "deepspeed": [
            "deepspeed==0.13.1",
            "deepspeed-kernels",
        ],
        "mamba-ssm": [
            "mamba-ssm==1.0.1",
        ],
        "auto-gptq": [
            "auto-gptq==0.5.1",
        ],
        "mlflow": [
            "mlflow",
        ],
        "lion-pytorch": [
            "lion-pytorch==0.1.2",
        ],
    },
)
|