Spaces:
Build error
Build error
import torch | |
import argparse | |
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel | |
from transformers import AutoTokenizer, TrainingArguments | |
from trainer.data import ChatDataModule | |
from trainer.mamba_trainer import MambaTrainer | |
def run(args): | |
print("Loading Mamba {} model".format(args.model)) | |
model = MambaLMHeadModel.from_pretrained(args.model, dtype=torch.bfloat16, device="cuda") | |
print("Loading tokenizer {}".format(args.tokenizer)) | |
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) | |
tokenizer.eos_token = "<|endoftext|>" | |
tokenizer.pad_token = tokenizer.eos_token | |
tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template | |
print("Loading data from {}".format(args.data_path)) | |
data_module = ChatDataModule( | |
tokenizer=tokenizer, | |
data_path=args.data_path, | |
conversation_template=tokenizer.chat_template, | |
max_tokens=2048 | |
) | |
print("Initializing trainer...") | |
trainer = MambaTrainer( | |
model=model, | |
train_dataset=data_module.dataset, | |
tokenizer=tokenizer, | |
args=TrainingArguments( | |
learning_rate=args.learning_rate, | |
num_train_epochs=args.num_epochs, | |
per_device_train_batch_size=args.batch_size, | |
gradient_accumulation_steps=args.gradient_accumulation_steps, | |
optim=args.optim, | |
output_dir="mamba-chat", | |
logging_steps=50, | |
save_steps=500, | |
), | |
data_collator=data_module.data_collator, | |
) | |
print("Training started...") | |
trainer.train() | |
print("Training finished!") | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--model", type=str, default="state-spaces/mamba-130m") | |
parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b") | |
parser.add_argument("--learning_rate", type=float, default=5e-5) | |
parser.add_argument("--batch_size", type=int, default=4) | |
parser.add_argument("--gradient_accumulation_steps", type=int, default=1) | |
parser.add_argument("--optim", type=str, default="adamw_torch") | |
parser.add_argument("--data_path", type=str, default="./data/ultrachat_small.jsonl") | |
parser.add_argument("--num_epochs", type=int, default=1) | |
args = parser.parse_args() | |
run(args) | |