Spaces: Build error
Pratik Dwivedi committed • Commit 56d31bf
Parent(s): cb4d237
trainer commit (#1)
Files changed:
- Dockerfile +15 -0
- app.py +89 -0
- chat.py +31 -0
- data/ultrachat_small.jsonl +0 -0
- requirements.txt +8 -0
- scripts/download_ultrachat.py +10 -0
- train_mamba.py +62 -0
- trainer/__init__.py +0 -0
- trainer/data.py +83 -0
- trainer/mamba_trainer.py +39 -0
Dockerfile
ADDED
@@ -0,0 +1,15 @@
+FROM python:3.9
+
+WORKDIR /code
+
+COPY . .
+
+# RUN pip install --no-cache-dir torch==2.2.1 --index-url https://download.pytorch.org/whl/cu121
+
+RUN pip install --no-cache-dir -r /code/requirements.txt
+
+COPY . .
+
+# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
+
+CMD ["python", "train_mamba.py", "--model", "state-spaces/mamba-130m", "--tokenizer", "EleutherAI/gpt-neox-20b", "--learning_rate", "5e-5", "--batch_size", "1", "--gradient_accumulation_steps", "1", "--optim", "paged_adamw_8bit", "--data_path", "./data/ultrachat_small.jsonl", "--num_epochs", "1"]
app.py
ADDED
@@ -0,0 +1,89 @@
+import gradio as gr
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+from argparse import ArgumentParser
+
+def get_args():
+    parser = ArgumentParser()
+    parser.add_argument("--port", type=int, default=7860)
+    parser.add_argument("--device", type=str, default='cuda', help='Device to run the model on')
+    parser.add_argument("--model", type=str, default='havenhq/mamba-chat', help='Model to use')
+    parser.add_argument(
+        "--share",
+        action="store_true",
+        default=False,
+        help="share your instance publicly through gradio",
+    )
+    try:
+        args = parser.parse_args()
+    except:
+        parser.print_help()
+        exit(0)
+    return args
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    device = args.device
+    model_name = args.model
+    eos = "<|endoftext|>"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer.eos_token = eos
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.chat_template = AutoTokenizer.from_pretrained(
+        "HuggingFaceH4/zephyr-7b-beta"
+    ).chat_template
+
+    model = MambaLMHeadModel.from_pretrained(
+        model_name, device=device, dtype=torch.float16
+    )
+
+    def chat_with_mamba(
+        user_message,
+        history: list[list[str]],
+        temperature: float = 0.9,
+        top_p: float = 0.7,
+        max_length: int = 2000,
+    ):
+        history_dict: list[dict[str, str]] = []
+        for user_m, assistant_m in history:
+            history_dict.append(dict(role="user", content=user_m))
+            history_dict.append(dict(role="assistant", content=assistant_m))
+        history_dict.append(dict(role="user", content=user_message))
+
+        input_ids = tokenizer.apply_chat_template(
+            history_dict, return_tensors="pt", add_generation_prompt=True
+        ).to(device)
+
+        out = model.generate(
+            input_ids=input_ids,
+            max_length=max_length,
+            temperature=temperature,
+            top_p=top_p,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
+        decoded = tokenizer.batch_decode(out)
+        assistant_message = (
+            decoded[0].split("<|assistant|>\n")[-1].replace(eos, "")
+        )
+        return assistant_message
+
+
+    demo = gr.ChatInterface(
+        fn=chat_with_mamba,
+        # examples=[
+        #     "Explain what is state space model",
+        #     "Nice to meet you!",
+        #     "'Mamba is way better than ChatGPT.' Is this statement correct?",
+        # ],
+        additional_inputs=[
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="temperature"),
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="top_p"),
+            gr.Number(value=2000, label="max_length"),
+        ],
+        title="Mamba Chat",
+    )
+    demo.launch(server_port=args.port, share=args.share)
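Note (not part of the diff): app.py rebuilds the Gradio chat history as role/content messages and relies on the borrowed Zephyr chat template ending the rendered prompt with an assistant marker, which is why the reply is recovered with decoded[0].split("<|assistant|>\n")[-1]. A minimal sketch of that flow, assuming the template renders <|user|>/<|assistant|> turns terminated by the overridden <|endoftext|> token (the exact string is defined by HuggingFaceH4/zephyr-7b-beta's template, not by this commit):

history = [["Hi!", "Hello! How can I help?"]]          # Gradio history: [user, assistant] pairs
user_message = "What is a state space model?"

history_dict = []
for user_m, assistant_m in history:
    history_dict.append(dict(role="user", content=user_m))
    history_dict.append(dict(role="assistant", content=assistant_m))
history_dict.append(dict(role="user", content=user_message))

# Roughly what apply_chat_template(..., add_generation_prompt=True) is expected to render;
# the trailing "<|assistant|>\n" is the split point used to extract the model's reply.
expected_prompt = (
    "<|user|>\nHi!<|endoftext|>\n"
    "<|assistant|>\nHello! How can I help?<|endoftext|>\n"
    "<|user|>\nWhat is a state space model?<|endoftext|>\n"
    "<|assistant|>\n"
)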
chat.py
ADDED
@@ -0,0 +1,31 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+device = "cuda"
+tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
+tokenizer.eos_token = "<|endoftext|>"
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template
+
+model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)
+
+messages = []
+while True:
+    user_message = input("\nYour message: ")
+    messages.append(dict(
+        role="user",
+        content=user_message
+    ))
+
+    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
+
+    out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
+
+    decoded = tokenizer.batch_decode(out)
+    messages.append(dict(
+        role="assistant",
+        content=decoded[0].split("<|assistant|>\n")[-1]
+    ))
+
+    print("Model:", decoded[0].split("<|assistant|>\n")[-1])
data/ultrachat_small.jsonl
ADDED
The diff for this file is too large to render.
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+packaging
+torch==2.2.1+cu121
+transformers
+causal-conv1d
+mamba-ssm
+accelerate
+bitsandbytes
+scipy==1.11.4
scripts/download_ultrachat.py
ADDED
@@ -0,0 +1,10 @@
+import json
+from datasets import load_dataset
+
+data = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
+
+
+with open("../data/ultrachat.jsonl", "w") as f:
+    for d in data:
+        f.write(json.dumps(dict(messages=d["messages"]))+"\n")
+
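Note (not part of the diff): each line this script writes is one UltraChat conversation serialized as {"messages": [...]}, which is the shape trainer/data.py's preprocess() expects (conv["messages"] with role/content entries). A sketch of one such record, plus one way a smaller file like data/ultrachat_small.jsonl could be cut from the full dump; the actual contents of the bundled ultrachat_small.jsonl and the working directory are assumptions here:

import json

# One JSONL record in the expected shape.
record = {
    "messages": [
        {"role": "user", "content": "How do Mamba models differ from transformers?"},
        {"role": "assistant", "content": "They replace attention with a selective state space layer..."},
    ]
}
print(json.dumps(record))

# Hypothetical way to keep only the first N conversations (paths assume the repo root).
N = 1000
with open("data/ultrachat.jsonl") as src, open("data/ultrachat_small.jsonl", "w") as dst:
    for i, line in enumerate(src):
        if i >= N:
            break
        dst.write(line)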
train_mamba.py
ADDED
@@ -0,0 +1,62 @@
+import torch
+import argparse
+
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+from transformers import AutoTokenizer, TrainingArguments
+from trainer.data import ChatDataModule
+from trainer.mamba_trainer import MambaTrainer
+
+
+def run(args):
+
+    print("Loading Mamba {} model".format(args.model))
+    model = MambaLMHeadModel.from_pretrained(args.model, dtype=torch.bfloat16, device="cuda")
+    print("Loading tokenizer {}".format(args.tokenizer))
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+    tokenizer.eos_token = "<|endoftext|>"
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template
+
+    print("Loading data from {}".format(args.data_path))
+    data_module = ChatDataModule(
+        tokenizer=tokenizer,
+        data_path=args.data_path,
+        conversation_template=tokenizer.chat_template,
+        max_tokens=2048
+    )
+
+    print("Initializing trainer...")
+    trainer = MambaTrainer(
+        model=model,
+        train_dataset=data_module.dataset,
+        tokenizer=tokenizer,
+        args=TrainingArguments(
+            learning_rate=args.learning_rate,
+            num_train_epochs=args.num_epochs,
+            per_device_train_batch_size=args.batch_size,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            optim=args.optim,
+            output_dir="mamba-chat",
+            logging_steps=50,
+            save_steps=500,
+        ),
+        data_collator=data_module.data_collator,
+    )
+    print("Training started...")
+    trainer.train()
+    print("Training finished!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="state-spaces/mamba-130m")
+    parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
+    parser.add_argument("--learning_rate", type=float, default=5e-5)
+    parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+    parser.add_argument("--optim", type=str, default="adamw_torch")
+    parser.add_argument("--data_path", type=str, default="./data/ultrachat_small.jsonl")
+    parser.add_argument("--num_epochs", type=int, default=1)
+    args = parser.parse_args()
+
+    run(args)
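Note (not part of the diff): the Dockerfile's CMD drives this script purely through argparse flags. A sketch of the equivalent programmatic call, assuming mamba-ssm and its CUDA kernels are installed and a GPU is available; the Namespace fields simply mirror the parser's arguments:

from argparse import Namespace
from train_mamba import run

# Same hyperparameters the Dockerfile passes on the command line.
args = Namespace(
    model="state-spaces/mamba-130m",
    tokenizer="EleutherAI/gpt-neox-20b",
    learning_rate=5e-5,
    batch_size=1,
    gradient_accumulation_steps=1,
    optim="paged_adamw_8bit",
    data_path="./data/ultrachat_small.jsonl",
    num_epochs=1,
)
run(args)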
trainer/__init__.py
ADDED
File without changes
trainer/data.py
ADDED
@@ -0,0 +1,83 @@
+import torch
+import transformers
+import json
+
+from dataclasses import dataclass
+from typing import Dict, Sequence
+from tqdm import tqdm
+from torch.utils.data import Dataset
+
+
+class ChatDataset(Dataset):
+    def __init__(self, data_path: str, tokenizer: transformers.AutoTokenizer, conversation_template: str, max_tokens: int):
+        super(ChatDataset, self).__init__()
+        data = []
+        with open(data_path, "r") as file:
+            for line in file:
+                try:
+                    data.append(json.loads(line))
+                except Exception as e:
+                    print("json processing exception", e)
+                    continue
+
+
+        data_dict = preprocess(data, tokenizer, conversation_template, max_tokens)
+
+        self.input_ids = data_dict["input_ids"]
+        self.labels = data_dict["labels"]
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+@dataclass
+class DataCollatorForChatDataset(object):
+    """
+    Collate examples for supervised fine-tuning.
+    """
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))  # labels mirror input_ids here (see preprocess)
+        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
+
+        return dict(
+            input_ids=input_ids,
+            labels=labels,
+            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+        )
+
+
+class ChatDataModule():
+    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template, max_tokens: int):
+
+        self.dataset = ChatDataset(tokenizer=tokenizer, data_path=data_path, conversation_template=conversation_template, max_tokens=max_tokens)
+        self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer)
+
+
+def preprocess(conversations: Sequence[Sequence[dict]], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict:
+    """
+    Preprocess the data by tokenizing.
+    """
+    all_input_ids = []
+    all_label_ids = []
+    tokenizer.use_default_system_prompt = False
+
+    print("Tokenizing dataset...")
+    for conv in tqdm(conversations):
+        current_conv = conv["messages"]
+        tokenized_responses = []
+        for msg in current_conv:
+            if msg["role"] == "assistant":
+                tokenized_responses.append(tokenizer.encode(msg["content"], add_special_tokens=False))
+
+        tokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True)
+        all_input_ids.append(torch.LongTensor(tokenized_conv))
+
+
+    return dict(input_ids=all_input_ids, labels=all_input_ids)
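Note (not part of the diff): preprocess() returns labels that are simply the input_ids, so the loss in MambaTrainer is computed over every token, including user turns; the tokenized_responses it collects per conversation are not used. If assistant-only supervision were wanted, a sketch of one approach (an assumption, not what this commit does) would be to locate each response span and mask everything else to -100:

import torch

def mask_non_assistant(input_ids: torch.LongTensor, tokenized_responses) -> torch.LongTensor:
    """Hypothetical helper: keep loss only on assistant tokens, mask the rest with -100."""
    labels = torch.full_like(input_ids, -100)
    ids = input_ids.tolist()
    for resp in tokenized_responses:
        n = len(resp)
        # Find the (assumed contiguous) response span inside the tokenized conversation.
        for start in range(len(ids) - n + 1):
            if ids[start:start + n] == resp:
                labels[start:start + n] = input_ids[start:start + n]
                break
    return labels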
trainer/mamba_trainer.py
ADDED
@@ -0,0 +1,39 @@
+from transformers import Trainer
+import torch
+import os
+
+
+class MambaTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False):
+        input_ids = inputs.pop("input_ids")
+        lm_logits = model(input_ids).logits
+
+        labels = input_ids.to(lm_logits.device)
+        shift_logits = lm_logits[:, :-1, :].contiguous()
+        labels = labels[:, 1:].contiguous()
+
+        loss_fct = torch.nn.CrossEntropyLoss()
+        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+
+        return lm_loss
+
+    def save_model(self, output_dir, _internal_call):
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
+        self.tokenizer.save_pretrained(output_dir)
+
+        json_str = """
+{
+    "d_model": 768,
+    "n_layer": 24,
+    "vocab_size": 50277,
+    "ssm_cfg": {},
+    "rms_norm": true,
+    "residual_in_fp32": true,
+    "fused_add_norm": true,
+    "pad_vocab_size_multiple": 8
+}"""
+        with open(f"{output_dir}/config.json", 'w') as f:
+            f.write(json_str)
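Note (not part of the diff): save_model() writes a raw state dict (pytorch_model.bin), the tokenizer files, and a hard-coded config.json matching mamba-130m (d_model=768, n_layer=24), so the output directory resembles a standard Mamba checkpoint. A sketch of loading it back for inference in the same way chat.py loads havenhq/mamba-chat, assuming training ran with the default output_dir="mamba-chat":

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("mamba-chat")  # local output_dir written by train_mamba.py
model = MambaLMHeadModel.from_pretrained("mamba-chat", device="cuda", dtype=torch.float16)

prompt_ids = tokenizer("The key idea behind state space models is", return_tensors="pt").input_ids.to("cuda")
out = model.generate(input_ids=prompt_ids, max_length=100, temperature=0.9, top_p=0.7,
                     eos_token_id=tokenizer.eos_token_id)
print(tokenizer.batch_decode(out)[0])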