File size: 622 Bytes
56d31bf
 
 
 
462285a
56d31bf
da2bc01
118a2e7
b931e4e
56d31bf
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Image that runs train_mamba.py (fine-tuning state-spaces/mamba-130m by default).
# NOTE(review): python:3.9 is past end-of-life; consider a newer minor version once
# torch==2.2.1 and the packages in requirements.txt are confirmed to support it.
FROM python:3.9

# Unbuffered stdout/stderr so training progress appears in `docker logs` immediately.
ENV PYTHONUNBUFFERED=1

# Non-root runtime user. /code is pre-created and owned by that user so the training
# script can write outputs next to the source. UID 1000 is also the UID Hugging Face
# Spaces runs containers as.
RUN useradd --create-home --uid 1000 user \
    && install -d -o user -g user /code

WORKDIR /code

# Build helpers, presumably for packages in requirements.txt that compile native
# extensions (ninja is the build backend torch extensions use) — verify.
# NOTE(review): confirm `buildtools` is the intended package; it is not a standard
# build dependency. Consider pinning versions of all three (DL3013).
RUN pip install --no-cache-dir packaging ninja buildtools

# Install the CUDA 12.1 torch build first, from the PyTorch wheel index, so the pinned
# GPU build is already present when requirements.txt is resolved.
RUN pip install --no-cache-dir torch==2.2.1 --index-url https://download.pytorch.org/whl/cu121

# Copy only the requirements manifest so this (slow) layer stays cached until it changes.
# If requirements.txt ever references local paths (e.g. `-e .`), those files must be
# copied before this step as well.
COPY requirements.txt /code/requirements.txt
RUN pip install --no-cache-dir -r /code/requirements.txt

# Application source and data last: code edits only invalidate this layer.
COPY --chown=user:user . .

# All root-only steps are done; run the workload unprivileged.
USER user

# Alternative entrypoint that serves the app instead of training:
# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]

# Exec form: every flag and every value must be its own array element. A combined
# "--optim paged_adamw_8bit" would reach the script as a single argv entry.
CMD ["python", "train_mamba.py", \
     "--model", "state-spaces/mamba-130m", \
     "--tokenizer", "EleutherAI/gpt-neox-20b", \
     "--learning_rate", "5e-5", \
     "--batch_size", "1", \
     "--gradient_accumulation_steps", "1", \
     "--optim", "paged_adamw_8bit", \
     "--data_path", "./data/ultrachat_small.jsonl", \
     "--num_epochs", "1"]