ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
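# Example build invocation overriding the defaults above (the image tag is
# illustrative only, not part of this file):
#   docker build --build-arg CUDA_VERSION=11.8.0 --build-arg PYTORCH=2.0.0 -t cuda-ml-base:dev .
# Note: ARGs declared before the first FROM are only in scope for FROM lines;
# a stage must re-declare an ARG (e.g. `ARG MAX_JOBS`) before it can read it.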

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.9"
ARG PYTORCH="2.0.0"
ARG CUDA="118"

ENV PYTHON_VERSION=$PYTHON_VERSION

# Run update and install in the same layer so the package index is never stale,
# and drop the apt lists afterwards to keep the layer small.
RUN apt-get update && \
    apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && \
    rm -rf /var/lib/apt/lists/*

# Install Miniconda non-interactively (-b) into its default prefix, /root/miniconda3.
RUN wget \
    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh

RUN conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
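# With the env's bin directory first on PATH, `python3` and `pip3` resolve to
# the py${PYTHON_VERSION} env; this stands in for `conda activate`, which is
# not available in non-interactive RUN shells.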

WORKDIR /workspace

RUN python3 -m pip install --upgrade pip && \
    python3 -m pip install packaging && \
    python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu$CUDA
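# (The cu$CUDA suffix, e.g. cu118, selects the PyTorch wheel index whose CUDA
# build matches the toolkit in the base image.)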


FROM base-builder AS flash-attn-builder

WORKDIR /workspace

ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
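# 7.0 = V100, 7.5 = T4/Turing, 8.0 = A100, 8.6 = consumer Ampere/A6000; +PTX
# embeds PTX so newer architectures can JIT-compile the kernels at load time.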

RUN git clone https://github.com/HazyResearch/flash-attention.git && \
    cd flash-attention && \
    python3 setup.py bdist_wheel && \
    cd csrc/fused_dense_lib && \
    python3 setup.py bdist_wheel && \
    cd ../xentropy && \
    python3 setup.py bdist_wheel && \
    cd ../rotary && \
    python3 setup.py bdist_wheel && \
    cd ../layer_norm && \
    python3 setup.py bdist_wheel
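# (Each csrc/ subdirectory above is an optional fused CUDA extension: fused
# dense layers, cross-entropy, rotary embeddings, and dropout + layer norm.
# Each ships as its own wheel and is installed in the final stage.)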

FROM base-builder AS deepspeed-builder

ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"

WORKDIR /workspace
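# The build below pre-compiles DeepSpeed's C++/CUDA ops (DS_BUILD_OPS=1) into
# the wheel instead of JIT-compiling them at first use; DS_BUILD_SPARSE_ATTN=0
# skips the sparse attention op.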

RUN git clone https://github.com/microsoft/DeepSpeed.git && \
    cd DeepSpeed && \
    MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 python3 setup.py bdist_wheel

FROM base-builder AS bnb-builder

WORKDIR /workspace
ARG CUDA="118"
ENV CUDA=$CUDA
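# `make cuda11x` below compiles libbitsandbytes_cuda11x.so against the CUDA
# 11.x toolkit before the Python wheel is assembled.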

RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
    cd bitsandbytes && \
    CUDA_VERSION=$CUDA make cuda11x && \
    python setup.py bdist_wheel

FROM base-builder

# recompile apex
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
# `MAX_JOBS=1` disables parallel building to avoid CPU out-of-memory failures when building the image on standard GitHub Actions runners.
RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .
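# Note: newer pip releases deprecate --global-option; if this step fails on a
# recent pip, apex's README documents a --config-settings based equivalent.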

RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes

RUN mkdir -p /workspace/wheels/bitsandbytes
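# Gather the artifacts produced by the builder stages; COPY --from pulls just
# these files without inheriting the rest of each builder's filesystem. The
# relative `wheels` destination resolves against the inherited WORKDIR,
# /workspace.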
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels

RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
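# bitsandbytes is installed from the copied source tree, presumably so the
# libbitsandbytes*.so binaries compiled in bnb-builder ship with the package.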
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
# The base image ships with `pydantic==1.8.2`, which does not work here; upgrade it.
RUN pip3 install awscli && \
    pip3 install -U --no-cache-dir pydantic