File size: 3,310 Bytes
8b617cc
 
fac2d98
 
161bcb6
 
37293dc
77fca25
3355706
 
 
 
 
34c0a86
3355706
d69ba2b
 
 
 
 
1648279
d69ba2b
3355706
 
 
 
d69ba2b
3355706
 
8d288a2
161bcb6
fac2d98
039e2a0
fac2d98
 
 
 
 
 
 
 
 
 
 
 
 
 
039e2a0
 
 
 
 
 
161bcb6
 
8d288a2
3355706
 
 
 
 
77fca25
 
2bc1a5b
1427d5b
6c5fbe6
 
2bc1a5b
77fca25
 
3355706
77fca25
cf66547
039e2a0
cf66547
732851f
039e2a0
732851f
c25ba79
039e2a0
8a49309
2bc1a5b
40a6362
05b398a
40a6362
9be92d1
 
 
5894f0e
 
 
1648279
 
 
dd449c5
 
 
77fca25
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""setup.py for axolotl"""

import platform
import re
from importlib.metadata import PackageNotFoundError, version

from setuptools import find_packages, setup


def parse_requirements():
    """Build the dependency lists for ``setup()`` from ``./requirements.txt``.

    Each line of the requirements file is classified as follows:

    * ``--extra-index-url <url>`` (or ``--extra-index-url=<url>``) lines
      contribute ``<url>`` to the dependency links.
    * Blank lines and ``#`` comment lines are ignored.
    * Lines mentioning a package that is shipped as an optional extra
      (flash-attn, deepspeed, mamba-ssm, lion-pytorch) are ignored.
    * Every other line is kept verbatim as an install requirement.

    The default ``xformers==0.0.23.post1`` pin is then adjusted to the
    environment: it is dropped on macOS, and when torch is already installed
    the installed torch version is pinned and the xformers pin is relaxed to
    ``>=0.0.26.post1`` (torch >= 2.3) or ``>=0.0.25.post1`` (torch >= 2.2).
    If the default pin is not present in the requirements file, the xformers
    requirement is left untouched instead of failing.

    Returns:
        A ``(install_requires, dependency_links)`` tuple of two lists of
        strings.

    Raises:
        ValueError: If the installed torch version string does not start
            with ``<major>.<minor>``.
    """
    extras_markers = (
        "flash-attn",
        "flash-attention",
        "deepspeed",
        "mamba-ssm",
        "lion-pytorch",
    )
    index_url_flag = "--extra-index-url"
    default_xformers_pin = "xformers==0.0.23.post1"

    _install_requires = []
    _dependency_links = []
    with open("./requirements.txt", encoding="utf-8") as requirements_file:
        for raw_line in requirements_file:
            line = raw_line.strip()
            if line.startswith(index_url_flag):
                # Accept both "--extra-index-url URL" and "--extra-index-url=URL";
                # anything after the URL on the same line is ignored.
                remainder = line[len(index_url_flag) :].lstrip(" \t=")
                if remainder:
                    _dependency_links.append(remainder.split()[0])
                continue
            if not line or line.startswith("#"):
                continue
            if any(marker in line for marker in extras_markers):
                # Installed through extras_require instead.
                continue
            _install_requires.append(line)

    def _swap_default_xformers(replacement=None):
        """Drop the default xformers pin (if listed) and optionally replace it."""
        if default_xformers_pin not in _install_requires:
            # The pin was changed or removed in requirements.txt; respect that
            # rather than crashing the install with a ValueError.
            return
        _install_requires.remove(default_xformers_pin)
        if replacement is not None:
            _install_requires.append(replacement)

    if "Darwin" in platform.system():
        # xformers is not required on macOS (presumably no usable build there).
        _swap_default_xformers()
        return _install_requires, _dependency_links

    try:
        torch_version = version("torch")
    except PackageNotFoundError:
        # torch is not installed yet: keep the requirements exactly as listed.
        return _install_requires, _dependency_links

    # Pin the already-installed torch so the installer does not replace it.
    _install_requires.append(f"torch=={torch_version}")

    version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
    if not version_match:
        raise ValueError("Invalid version format")
    major, minor = int(version_match.group(1)), int(version_match.group(2))

    if (major, minor) >= (2, 3):
        _swap_default_xformers("xformers>=0.0.26.post1")
    elif (major, minor) >= (2, 2):
        _swap_default_xformers("xformers>=0.0.25.post1")

    return _install_requires, _dependency_links


# Resolve the dependency lists once at import time so they can be passed to setup().
install_requires, dependency_links = parse_requirements()


setup(
    name="axolotl",
    version="0.4.0",
    description="LLM Trainer",
    long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
    # The importable packages live under src/, so package discovery must look
    # there as well; a bare find_packages() would scan the project root.
    package_dir={"": "src"},
    packages=find_packages("src"),
    install_requires=install_requires,
    dependency_links=dependency_links,
    # Optional dependencies; the matching lines in requirements.txt are
    # filtered out by parse_requirements() so they are only installed on request.
    extras_require={
        "flash-attn": [
            "flash-attn==2.5.8",
        ],
        "fused-dense-lib": [
            "fused-dense-lib  @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
        ],
        "deepspeed": [
            "deepspeed==0.14.2",
            "deepspeed-kernels",
        ],
        "mamba-ssm": [
            "mamba-ssm==1.2.0.post1",
        ],
        "auto-gptq": [
            "auto-gptq==0.5.1",
        ],
        "mlflow": [
            "mlflow",
        ],
        "lion-pytorch": [
            "lion-pytorch==0.1.2",
        ],
        "galore": [
            "galore_torch",
        ],
    },
)