Pratik Dwivedi committed on
Commit 56d31bf
1 Parent(s): cb4d237

trainer commit (#1)

Dockerfile ADDED
@@ -0,0 +1,15 @@
+ FROM python:3.9
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ # RUN pip install --no-cache-dir torch==2.2.1 --index-url https://download.pytorch.org/whl/cu121
+
+ RUN pip install --no-cache-dir -r /code/requirements.txt
+
+ COPY . .
+
+ # CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
+
+ CMD ["python", "train_mamba.py", "--model", "state-spaces/mamba-130m", "--tokenizer", "EleutherAI/gpt-neox-20b", "--learning_rate", "5e-5", "--batch_size", "1", "--gradient_accumulation_steps", "1", "--optim", "paged_adamw_8bit", "--data_path", "./data/ultrachat_small.jsonl", "--num_epochs", "1"]
app.py ADDED
@@ -0,0 +1,89 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+ from argparse import ArgumentParser
+
+ def get_args():
+     parser = ArgumentParser()
+     parser.add_argument("--port", type=int, default=7860)
+     parser.add_argument("--device", type=str, default='cuda', help='Device to run the model on')
+     parser.add_argument("--model", type=str, default='havenhq/mamba-chat', help='Model to use')
+     parser.add_argument(
+         "--share",
+         action="store_true",
+         default=False,
+         help="share your instance publicly through gradio",
+     )
+     try:
+         args = parser.parse_args()
+     except:
+         parser.print_help()
+         exit(0)
+     return args
+
+
+ if __name__ == "__main__":
+     args = get_args()
+
+     device = args.device
+     model_name = args.model
+     eos = "<|endoftext|>"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     tokenizer.eos_token = eos
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.chat_template = AutoTokenizer.from_pretrained(
+         "HuggingFaceH4/zephyr-7b-beta"
+     ).chat_template
+
+     model = MambaLMHeadModel.from_pretrained(
+         model_name, device=device, dtype=torch.float16
+     )
+
+     def chat_with_mamba(
+         user_message,
+         history: list[list[str]],
+         temperature: float = 0.9,
+         top_p: float = 0.7,
+         max_length: int = 2000,
+     ):
+         history_dict: list[dict[str, str]] = []
+         for user_m, assistant_m in history:
+             history_dict.append(dict(role="user", content=user_m))
+             history_dict.append(dict(role="assistant", content=assistant_m))
+         history_dict.append(dict(role="user", content=user_message))
+
+         input_ids = tokenizer.apply_chat_template(
+             history_dict, return_tensors="pt", add_generation_prompt=True
+         ).to(device)
+
+         out = model.generate(
+             input_ids=input_ids,
+             max_length=max_length,
+             temperature=temperature,
+             top_p=top_p,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+
+         decoded = tokenizer.batch_decode(out)
+         assistant_message = (
+             decoded[0].split("<|assistant|>\n")[-1].replace(eos, "")
+         )
+         return assistant_message
+
+
+     demo = gr.ChatInterface(
+         fn=chat_with_mamba,
+         # examples=[
+         #     "Explain what is state space model",
+         #     "Nice to meet you!",
+         #     "'Mamba is way better than ChatGPT.' Is this statement correct?",
+         # ],
+         additional_inputs=[
+             gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="temperature"),
+             gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="top_p"),
+             gr.Number(value=2000, label="max_length"),
+         ],
+         title="Mamba Chat",
+     )
+     demo.launch(server_port=args.port, share=args.share)
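
A quick way to see what the borrowed Zephyr chat template actually feeds the model is to render it to a string instead of token IDs. A minimal sketch (not part of the commit) reusing the same tokenizer setup as app.py; the example message is purely illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template

history = [{"role": "user", "content": "What is a state space model?"}]
# tokenize=False returns the formatted prompt string rather than input IDs.
prompt = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
print(prompt)  # roughly: "<|user|>\nWhat is a state space model?</s>\n<|assistant|>\n"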
chat.py ADDED
@@ -0,0 +1,31 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+ device = "cuda"
+ tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
+ tokenizer.eos_token = "<|endoftext|>"
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template
+
+ model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)
+
+ messages = []
+ while True:
+     user_message = input("\nYour message: ")
+     messages.append(dict(
+         role="user",
+         content=user_message
+     ))
+
+     input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
+
+     out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
+
+     decoded = tokenizer.batch_decode(out)
+     messages.append(dict(
+         role="assistant",
+         content=decoded[0].split("<|assistant|>\n")[-1])
+     )
+
+     print("Model:", decoded[0].split("<|assistant|>\n")[-1])
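
Splitting the decoded string on "<|assistant|>\n" works, but it leaves the trailing <|endoftext|> token inside the stored history. A hedged alternative (not in the commit) is to decode only the newly generated tokens, assuming out is the generated sequence tensor as in chat.py above:

new_tokens = out[0][input_ids.shape[1]:]                         # drop the prompt tokens
reply = tokenizer.decode(new_tokens, skip_special_tokens=True)   # strips <|endoftext|>
messages.append(dict(role="assistant", content=reply))
print("Model:", reply)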
data/ultrachat_small.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ packaging
+ torch==2.2.1+cu121
+ transformers
+ causal-conv1d
+ mamba-ssm
+ accelerate
+ bitsandbytes
+ scipy==1.11.4
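
Note: the torch==2.2.1+cu121 pin resolves only against PyTorch's CUDA 12.1 wheel index (https://download.pytorch.org/whl/cu121), which is the index referenced in the Dockerfile's commented-out RUN line; that index has to be supplied to pip in some form when installing these requirements.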
scripts/download_ultrachat.py ADDED
@@ -0,0 +1,10 @@
+ import json
+ from datasets import load_dataset
+
+ data = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
+
+
+ with open("../data/ultrachat.jsonl", "w") as f:
+     for d in data:
+         f.write(json.dumps(dict(messages=d["messages"]))+"\n")
+
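
The script above writes the full split to data/ultrachat.jsonl, while the Dockerfile and train_mamba.py default to data/ultrachat_small.jsonl. A hedged sketch of producing that smaller file by keeping only the first conversations (the 1000-row cutoff is an illustrative choice, not from the commit):

import json
from datasets import load_dataset

# Take a small prefix of the SFT split to keep the file lightweight.
data = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft").select(range(1000))

with open("../data/ultrachat_small.jsonl", "w") as f:
    for d in data:
        f.write(json.dumps(dict(messages=d["messages"])) + "\n")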
train_mamba.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import argparse
+
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+ from transformers import AutoTokenizer, TrainingArguments
+ from trainer.data import ChatDataModule
+ from trainer.mamba_trainer import MambaTrainer
+
+
+ def run(args):
+
+     print("Loading Mamba {} model".format(args.model))
+     model = MambaLMHeadModel.from_pretrained(args.model, dtype=torch.bfloat16, device="cuda")
+     print("Loading tokenizer {}".format(args.tokenizer))
+     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+     tokenizer.eos_token = "<|endoftext|>"
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template
+
+     print("Loading data from {}".format(args.data_path))
+     data_module = ChatDataModule(
+         tokenizer=tokenizer,
+         data_path=args.data_path,
+         conversation_template=tokenizer.chat_template,
+         max_tokens=2048
+     )
+
+     print("Initializing trainer...")
+     trainer = MambaTrainer(
+         model=model,
+         train_dataset=data_module.dataset,
+         tokenizer=tokenizer,
+         args=TrainingArguments(
+             learning_rate=args.learning_rate,
+             num_train_epochs=args.num_epochs,
+             per_device_train_batch_size=args.batch_size,
+             gradient_accumulation_steps=args.gradient_accumulation_steps,
+             optim=args.optim,
+             output_dir="mamba-chat",
+             logging_steps=50,
+             save_steps=500,
+         ),
+         data_collator=data_module.data_collator,
+     )
+     print("Training started...")
+     trainer.train()
+     print("Training finished!")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", type=str, default="state-spaces/mamba-130m")
+     parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
+     parser.add_argument("--learning_rate", type=float, default=5e-5)
+     parser.add_argument("--batch_size", type=int, default=4)
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+     parser.add_argument("--optim", type=str, default="adamw_torch")
+     parser.add_argument("--data_path", type=str, default="./data/ultrachat_small.jsonl")
+     parser.add_argument("--num_epochs", type=int, default=1)
+     args = parser.parse_args()
+
+     run(args)
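
The Dockerfile drives this script through its CLI; the same run can also be reproduced programmatically from the repository root by building an argparse.Namespace with the same fields. A minimal sketch (the values mirror the Dockerfile's CMD; paged_adamw_8bit relies on bitsandbytes, which is in requirements.txt):

from argparse import Namespace
from train_mamba import run

args = Namespace(
    model="state-spaces/mamba-130m",
    tokenizer="EleutherAI/gpt-neox-20b",
    learning_rate=5e-5,
    batch_size=1,
    gradient_accumulation_steps=1,
    optim="paged_adamw_8bit",
    data_path="./data/ultrachat_small.jsonl",
    num_epochs=1,
)
run(args)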
trainer/__init__.py ADDED
File without changes
trainer/data.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import transformers
+ import json
+
+ from dataclasses import dataclass
+ from typing import Dict, Sequence
+ from tqdm import tqdm
+ from torch.utils.data import Dataset
+
+
+ class ChatDataset(Dataset):
+     def __init__(self, data_path: str, tokenizer: transformers.AutoTokenizer, conversation_template: str, max_tokens: int):
+         super(ChatDataset, self).__init__()
+         data = []
+         with open(data_path, "r") as file:
+             for line in file:
+                 try:
+                     data.append(json.loads(line))
+                 except Exception as e:
+                     print("json processing exception", e)
+                     continue
+
+
+         data_dict = preprocess(data, tokenizer, conversation_template, max_tokens)
+
+         self.input_ids = data_dict["input_ids"]
+         self.labels = data_dict["labels"]
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+         return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+ @dataclass
+ class DataCollatorForChatDataset(object):
+     """
+     Collate examples for supervised fine-tuning.
+     """
+
+     tokenizer: transformers.PreTrainedTokenizer
+
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+         input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+         input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
+
+         return dict(
+             input_ids=input_ids,
+             labels=labels,
+             attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+         )
+
+
+ class ChatDataModule():
+     def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template, max_tokens: int):
+
+         self.dataset = ChatDataset(tokenizer=tokenizer, data_path=data_path, conversation_template=conversation_template, max_tokens=max_tokens)
+         self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer)
+
+
+ def preprocess(conversations: Sequence[Sequence[dict]], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict:
+     """
+     Preprocess the data by tokenizing.
+     """
+     all_input_ids = []
+     all_label_ids = []
+     tokenizer.use_default_system_prompt = False
+
+     print("Tokenizing dataset...")
+     for conv in tqdm(conversations):
+         current_conv = conv["messages"]
+         tokenized_responses = []  # assistant replies, collected but not yet used for label masking
+         for msg in current_conv:
+             if msg["role"] == "assistant":
+                 tokenized_responses.append(tokenizer.encode(msg["content"], add_special_tokens=False))
+
+         tokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True)
+         all_input_ids.append(torch.LongTensor(tokenized_conv))
+
+
+     return dict(input_ids=all_input_ids, labels=all_input_ids)  # labels = full conversation, so the loss covers every token
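
preprocess() collects tokenized_responses but never uses it, and it returns the full conversation as both input_ids and labels, so the loss is also computed on user turns. One possible way to use those response spans for label masking, as a hedged sketch (not part of the commit; the naive sublist search is illustrative and assumes the assistant tokens reappear verbatim inside the templated conversation):

import torch

IGNORE_INDEX = -100  # matches the label padding value used by DataCollatorForChatDataset

def mask_non_assistant(tokenized_conv, tokenized_responses):
    """Return labels equal to the conversation tokens on assistant spans and -100 elsewhere."""
    labels = torch.full((len(tokenized_conv),), IGNORE_INDEX, dtype=torch.long)
    conv = list(tokenized_conv)
    for resp in tokenized_responses:
        n = len(resp)
        # naive sublist search for the assistant reply inside the templated conversation
        for start in range(len(conv) - n + 1):
            if conv[start:start + n] == list(resp):
                labels[start:start + n] = torch.tensor(resp, dtype=torch.long)
                break
    return labels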
trainer/mamba_trainer.py ADDED
@@ -0,0 +1,39 @@
+ from transformers import Trainer
+ import torch
+ import os
+
+
+ class MambaTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False):
+         input_ids = inputs.pop("input_ids")
+         lm_logits = model(input_ids).logits
+
+         # Causal-LM loss: predict token t+1 from the tokens up to t.
+         labels = input_ids.to(lm_logits.device)
+         shift_logits = lm_logits[:, :-1, :].contiguous()
+         labels = labels[:, 1:].contiguous()
+
+         loss_fct = torch.nn.CrossEntropyLoss()
+         lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+
+         return lm_loss
+
+     def save_model(self, output_dir=None, _internal_call=False):
+         if output_dir is None:
+             output_dir = self.args.output_dir
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
+         self.tokenizer.save_pretrained(output_dir)
+
+         # Hard-coded config matching state-spaces/mamba-130m.
+         json_str = """
+         {
+         "d_model": 768,
+         "n_layer": 24,
+         "vocab_size": 50277,
+         "ssm_cfg": {},
+         "rms_norm": true,
+         "residual_in_fp32": true,
+         "fused_add_norm": true,
+         "pad_vocab_size_multiple": 8
+         }"""
+         with open(f"{output_dir}/config.json", 'w') as f:
+             f.write(json_str)
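
save_model() writes a raw state dict plus a config.json whose dimensions match state-spaces/mamba-130m. A hedged sketch of loading such a checkpoint back (assumes mamba_ssm exposes MambaConfig at mamba_ssm.models.config_mamba, as in recent versions; the directory name follows the output_dir="mamba-chat" used above):

import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Rebuild the architecture described by the saved config.json, then load the weights.
config = MambaConfig(d_model=768, n_layer=24, vocab_size=50277,
                     rms_norm=True, residual_in_fp32=True,
                     fused_add_norm=True, pad_vocab_size_multiple=8)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)
model.load_state_dict(torch.load("mamba-chat/pytorch_model.bin", map_location="cuda"))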