import torch, os, wandb, uuid, json
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, BitsAndBytesConfig, TrainerCallback
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset, DatasetDict, Dataset, load_from_disk
from functools import partial
set_seed(42)
accelerator = Accelerator()
run_id = str(uuid.uuid4())
modelpath="microsoft/phi-2"
dataset_name="teknium/OpenHermes-2.5"
lr=0.00002
bs=10 # batch size
bs_eval=16 # batch size for evals
ga_steps=4 # gradient acc. steps
epochs=1
max_length=1024
output_dir=f"out_{run_id}"
# Load model
model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    device_map={"": accelerator.process_index},
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    # does not work yet
    # attn_implementation="flash_attention_2",
)
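# Optional check (illustrative addition, not part of the original script): with load_in_4bit
# the weights are stored in NF4, so the footprint should be roughly a quarter of an fp16 load.
# Uncomment to inspect on your machine:
# print(f"memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")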
# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False) # fast tokenizer sometimes ignores the added tokens
# Add ChatML tokens: <|im_start|> and <PAD> as regular tokens, <|im_end|> as the special eos token
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))
model.config.eos_token_id = tokenizer.eos_token_id
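# Sanity check (illustrative addition, not part of the original script): phi-2's embedding
# matrix is padded beyond the base vocabulary, so the tokens added above should fit without
# calling model.resize_token_embeddings(). Uncomment to verify on your setup:
# new_ids = tokenizer.convert_tokens_to_ids(["<|im_start|>", "<|im_end|>", "<PAD>"])
# assert max(new_ids) < model.get_input_embeddings().weight.shape[0]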
# Add adapters to model
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "dense"],
    modules_to_save=["lm_head", "embed_tokens"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.config.use_cache = False
# Print stats
if accelerator.is_main_process:
    model.print_trainable_parameters()
# Load dataset
with accelerator.main_process_first():
    dataset = load_dataset(dataset_name)
    dataset = dataset["train"].train_test_split(test_size=0.1)
# Format (chatML) and tokenize dataset
templates = {
    "system": "<|im_start|>system\n{msg}<|im_end|>",
    "human": "<|im_start|>user\n{msg}<|im_end|>",
    "gpt": "<|im_start|>assistant\n{msg}<|im_end|>",
}
IGNORE_INDEX=-100
def tokenize(input, max_length):
    input_ids, attention_mask, labels = [], [], []
    for msg in input["conversations"]:
        msg_role = msg["from"]
        msg_content = msg["value"]
        isHuman = msg_role == "human"
        if msg_role not in templates:
            raise ValueError(f"unexpected role '{msg_role}' in conversation")
        msg_chatml = templates[msg_role].format(msg=msg_content)
        msg_tokenized = tokenizer(msg_chatml, truncation=False, add_special_tokens=False)
        input_ids += msg_tokenized["input_ids"]
        attention_mask += msg_tokenized["attention_mask"]
        # mask human turns (labels = IGNORE_INDEX) so no loss is computed on them
        labels += [IGNORE_INDEX] * len(msg_tokenized["input_ids"]) if isHuman else msg_tokenized["input_ids"]
    return {
        "input_ids": input_ids[:max_length],
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length],
    }
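# Illustrative example (added for clarity, not part of the original flow): a "human" turn
# renders as "<|im_start|>user\nHi there<|im_end|>" and its labels are all IGNORE_INDEX,
# so the loss is only computed on assistant ("gpt") tokens. Uncomment to inspect:
# _example = {"conversations": [{"from": "human", "value": "Hi there"},
#                               {"from": "gpt", "value": "Hello!"}]}
# _tok = tokenize(_example, max_length)
# print(tokenizer.decode(_tok["input_ids"]))
# print(_tok["labels"])   # -100 for the user turn, token ids for the assistant turn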
dataset_tokenized = dataset.map(
    partial(tokenize, max_length=max_length),
    batched=False,
    # num_proc=os.cpu_count()//accelerator.num_processes,
    num_proc=os.cpu_count(),    # tokenize with multiple processes
    remove_columns=dataset["train"].column_names,   # the raw columns are not needed anymore, we have tokens from here on
)
# collate function: pad a list of samples [{"input_ids": [..], ..}, ..] into a single
# batch dict {"input_ids": tensor, "labels": tensor, "attention_mask": tensor}
def collate(elements):
    tokens = [e["input_ids"] for e in elements]
    tokens_maxlen = max(len(t) for t in tokens)
    for sample in elements:
        input_ids = sample["input_ids"]
        labels = sample["labels"]
        attention_mask = sample["attention_mask"]
        # pad every sample to the length of the longest sample in the batch
        pad_len = tokens_maxlen - len(input_ids)
        input_ids.extend(pad_len * [tokenizer.pad_token_id])
        labels.extend(pad_len * [IGNORE_INDEX])
        attention_mask.extend(pad_len * [0])
    batch = {
        "input_ids": torch.tensor([e["input_ids"] for e in elements]),
        "labels": torch.tensor([e["labels"] for e in elements]),
        "attention_mask": torch.tensor([e["attention_mask"] for e in elements]),
    }
    return batch
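# Illustrative check (added, not part of the original script): every tensor in a collated
# batch has the same shape, (batch_size, longest_sample_in_batch). Uncomment to inspect:
# _batch = collate([dataset_tokenized["train"][0], dataset_tokenized["train"][1]])
# print({k: v.shape for k, v in _batch.items()})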
steps_per_epoch=len(dataset_tokenized["train"])//(accelerator.num_processes*bs*ga_steps)
args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=bs_eval,
    evaluation_strategy="steps",
    logging_steps=1,
    eval_steps=steps_per_epoch//3,      # eval three times per epoch
    save_steps=steps_per_epoch//3,      # save three times per epoch
    gradient_accumulation_steps=ga_steps,
    num_train_epochs=epochs,
    lr_scheduler_type="constant",
    optim="paged_adamw_32bit",          # val_loss will go nan with paged_adamw_8bit
    learning_rate=lr,
    group_by_length=False,
    bf16=True,
    ddp_find_unused_parameters=False,
)
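# Note (added for clarity): the effective batch size per optimizer step is
# per_device_train_batch_size * gradient_accumulation_steps * num_processes,
# i.e. bs * ga_steps * accelerator.num_processes = 40 * number of GPUs with the settings above,
# which is the same quantity the steps_per_epoch calculation divides by.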
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collate,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
)
if accelerator.is_main_process:
    run = wandb.init(
        project="phi2-teknium1",
        name=f"{modelpath}_{dataset_name}_bs-{bs}_LR-{lr}_GPUs-{accelerator.num_processes}_maxlen-{max_length}_{run_id}",
        config={
            "model_name": modelpath,
            "run_id": run_id,
            "dataset": dataset_name,
            "output_dir": output_dir,
            "lr": lr,
            "max_length": max_length,
            "train_batch_size": bs,
            "validation_batch_size": bs_eval,
            "ga_steps": ga_steps,
            "lora_config": lora_config,
            "training_args": args,
            "GPUs": accelerator.num_processes,
        },
    )
    run.log_code()
trainer.train()