File size: 2,045 Bytes
8b617cc
 
161bcb6
 
37293dc
77fca25
3355706
 
 
 
 
34c0a86
3355706
 
 
 
 
c25ba79
 
732851f
c25ba79
 
 
 
3355706
 
8d288a2
161bcb6
 
 
 
 
 
 
8d288a2
3355706
 
 
 
 
77fca25
 
2bc1a5b
772cd87
6c5fbe6
 
2bc1a5b
77fca25
 
3355706
77fca25
cf66547
06ae392
cf66547
732851f
 
 
c25ba79
2bc1a5b
 
40a6362
 
 
77fca25
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""setup.py for axolotl"""

from importlib.metadata import PackageNotFoundError, version

from setuptools import find_packages, setup


def parse_requirements(requirements_path="./requirements.txt"):
    """Read a requirements file and split it into setuptools inputs.

    Blank lines, ``#`` comment lines and any line mentioning ``flash-attn``,
    ``flash-attention`` or ``deepspeed`` are skipped. Those packages are
    offered as optional extras in ``setup()`` instead. ``--extra-index-url``
    lines are collected separately, in either the ``--extra-index-url URL``
    or the ``--extra-index-url=URL`` spelling.

    When the installed torch version starts with ``2.1.1``, a
    ``xformers==0.0.22`` pin (if present) is replaced by
    ``xformers==0.0.23``.

    Args:
        requirements_path: Path of the requirements file. Defaults to
            ``./requirements.txt``, resolved against the current working
            directory.

    Returns:
        A ``(install_requires, dependency_links)`` tuple of string lists.

    Raises:
        ValueError: If an ``--extra-index-url`` line carries no URL.
        OSError: If the requirements file cannot be opened.
    """
    extra_index_prefix = "--extra-index-url"
    # Requirements containing any of these substrings are left out of
    # install_requires (they are provided as extras instead).
    skipped_markers = ("flash-attn", "flash-attention", "deepspeed")

    _install_requires = []
    _dependency_links = []
    with open(requirements_path, encoding="utf-8") as requirements_file:
        for raw_line in requirements_file:
            line = raw_line.strip()
            if line.startswith(extra_index_prefix):
                # Handle custom index URLs: accept both "OPT URL" and
                # "OPT=URL", and ignore anything after the URL (such as a
                # trailing comment).
                remainder = line[len(extra_index_prefix):].lstrip(" \t=")
                if not remainder:
                    raise ValueError(
                        f"Missing URL in requirements line: {line!r}"
                    )
                _dependency_links.append(remainder.split()[0])
            elif (
                line
                and not line.startswith("#")
                and not any(marker in line for marker in skipped_markers)
            ):
                # Handle standard packages
                _install_requires.append(line)

    try:
        torch_version = version("torch")
    except PackageNotFoundError:
        # torch is not installed yet: leave the requirement pins unchanged.
        torch_version = None

    if torch_version is not None and torch_version.startswith("2.1.1"):
        # Only swap the pin when the file actually contains it. The original
        # unconditional .index() raised ValueError (and aborted setup.py)
        # whenever the requirements file pinned a different xformers version.
        if "xformers==0.0.22" in _install_requires:
            _install_requires.remove("xformers==0.0.22")
            _install_requires.append("xformers==0.0.23")

    return _install_requires, _dependency_links


# Parsed at import time from ./requirements.txt (relative to the current
# working directory); flash-attn / deepspeed lines are excluded there and
# exposed below as opt-in extras.
install_requires, dependency_links = parse_requirements()


setup(
    name="axolotl",
    version="0.3.0",
    description="LLM Trainer",
    long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
    # Packages live under src/ (the root package maps to "src"), so package
    # discovery has to start there as well. A bare find_packages() scans the
    # repository root and does not descend into src/ (presumably src/ has no
    # __init__.py), so the project's packages would be missed.
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    install_requires=install_requires,
    # NOTE(review): modern pip ignores dependency_links; the extra index URLs
    # collected here may need to be passed to pip explicitly — confirm.
    dependency_links=dependency_links,
    extras_require={
        "flash-attn": [
            "flash-attn==2.3.3",
        ],
        "fused-dense-lib": [
            "fused-dense-lib  @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
        ],
        "deepspeed": [
            "deepspeed",
        ],
        "mamba-ssm": [
            "mamba-ssm==1.0.1",
        ],
    },
)