# medusa-maker/src/medusa_training_script.py
"""
Training script for the Medusa model.
Adapted from the original code here: https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py
"""
import os
from dataclasses import dataclass, field
import pathlib
from typing import Dict, Optional
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer, BitsAndBytesConfig
from transformers.trainer_pt_utils import LabelSmoother
from torch.nn import CrossEntropyLoss
from medusa.model.medusa_model import MedusaModel, MedusaConfig
from calibration_datasets import CalibrationDataset
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
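# LabelSmoother.ignore_index is -100, the value torch.nn.CrossEntropyLoss ignores by
# default, so label positions set to IGNORE_TOKEN_ID contribute nothing to the loss.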
# Customized for training Medusa heads
class CustomizedTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
"""
Compute the training loss for the model.
Args:
model (torch.nn.Module): The model for which to compute the loss.
inputs (dict): The input data, including input IDs, attention mask, and labels.
return_outputs (bool): Whether to return model outputs along with the loss.
Returns:
Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs.
"""
# DDP will give us model.module
if hasattr(model, "module"):
medusa = model.module.medusa
else:
medusa = model.medusa
logits = model(
input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
)
labels = inputs["labels"]
# Shift so that tokens < n predict n
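        # Head i is trained with an offset of (2 + i): its logits at position t are matched
        # against the label at position t + 2 + i, i.e. (i + 1) tokens beyond the usual
        # next-token target, which is what the head has to propose at decoding time.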
loss = 0
loss_fct = CrossEntropyLoss()
log = {}
for i in range(medusa):
medusa_logits = logits[i, :, : -(2 + i)].contiguous()
medusa_labels = labels[..., 2 + i :].contiguous()
medusa_logits = medusa_logits.view(-1, logits.shape[-1])
medusa_labels = medusa_labels.view(-1)
medusa_labels = medusa_labels.to(medusa_logits.device)
loss_i = loss_fct(medusa_logits, medusa_labels)
loss += loss_i
not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
medusa_labels = medusa_labels[not_ignore]
# Add top-k accuracy
for k in range(1, 6):
_, topk = medusa_logits.topk(k, dim=-1)
topk = topk[not_ignore]
correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
log[f"medusa{i}_top{k}"] = correct.float().mean().item()
log[f"medusa{i}_loss"] = loss_i.item()
self.log(log)
return (loss, logits) if return_outputs else loss
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(
        metadata={"help": "Hugging Face Hub name or local path of the base model."},
    )
load_in_4bit: bool = field(
default=False,
metadata={"help": "Load in 4 bit."},
)
load_in_8bit: bool = field(
default=False,
metadata={"help": "Load in 8 bit."},
)
@dataclass
class DataArguments:
dataset: str = field(
metadata={"help": "One of the datasets names in a CalibrationDataset subclass."},
)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=2048,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
medusa_num_heads: int = field(
default=1,
metadata={"help": "Number of Medusa heads."},
)
medusa_num_layers: int = field(
default=1,
metadata={"help": "Number of layers for each Medusa head."},
)
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""
Save the model's state dictionary to a specified directory.
Args:
trainer (transformers.Trainer): The Hugging Face Trainer object.
output_dir (str): The directory where the model state dictionary will be saved.
"""
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning.
Args:
        dataset (str): Name of one of the datasets defined by a CalibrationDataset subclass.
tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
"""
def __init__(self, dataset, tokenizer: transformers.PreTrainedTokenizer):
super(SupervisedDataset, self).__init__()
rank0_print("Formatting inputs...")
        dataset_classes = CalibrationDataset.__subclasses__()
        for dataset_class in dataset_classes:
            if dataset_class.dataset == dataset:
                dataset = dataset_class(num_samples=int(1e6), seqlen=tokenizer.model_max_length, tokenizer=tokenizer)
                break
        else:
            raise ValueError(f"Dataset '{dataset}' does not match any CalibrationDataset subclass.")
tokenized = dataset.tokenize_dataset()
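        # tokenize_dataset() is expected to return fixed-length samples (padded/truncated
        # to tokenizer.model_max_length), so they can be stacked into 2-D tensors below.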
self.input_ids = torch.tensor([data["input_ids"] for data in tokenized], dtype=torch.long)
self.attention_mask = torch.tensor([data["attention_mask"] for data in tokenized], dtype=torch.long)
def __len__(self):
return self.input_ids.shape[0]
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(
input_ids=self.input_ids[i],
labels=self.input_ids[i],
attention_mask=self.attention_mask[i],
)
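# Interface that SupervisedDataset assumes of CalibrationDataset subclasses (a sketch
# inferred from the usage above, not the actual implementation in calibration_datasets.py):
#
#     class MyDataset(CalibrationDataset):
#         dataset = "my-dataset"  # name matched against the --dataset argument
#
#         def __init__(self, num_samples: int, seqlen: int, tokenizer): ...
#
#         def tokenize_dataset(self) -> list:
#             """Return dicts with equal-length "input_ids" and "attention_mask"."""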
def train():
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
local_rank = training_args.local_rank
config = transformers.AutoConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
)
config.use_cache = False
    # Build an optional bitsandbytes quantization config from the CLI flags. Passing
    # `load_in_4bit`/`load_in_8bit` alongside `quantization_config` raises an error in
    # transformers, so the flags are folded into a single config here.
    quantization_config = None
    if model_args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif model_args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    # Load model and tokenizer
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
# Freeze the base model
for param in model.base_model.parameters():
param.requires_grad = False
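    # Only the newly added Medusa heads will receive gradients; the (possibly quantized)
    # base model stays frozen throughout training.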
# Add Medusa heads
medusa_lm_head = MedusaModel(
model,
medusa_num_heads=training_args.medusa_num_heads,
medusa_num_layers=training_args.medusa_num_layers,
base_model_name_or_path=model_args.model_name_or_path,
)
# Format output dir
training_args.output_dir = f"{training_args.output_dir}_medusa_{model_args.model_name_or_path.split('/')[-1]}"
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
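    # Some base tokenizers (e.g. Llama) ship without a pad token; reusing the unk token
    # assumes the tokenizer defines one, otherwise a pad token must be added explicitly.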
# Load data
data_module = {"train_dataset": SupervisedDataset(data_args.dataset, tokenizer), "eval_dataset": None}
# Generate Medusa config for pushing to HF hub
medusa_config = MedusaConfig(
medusa_num_heads=training_args.medusa_num_heads,
medusa_num_layers=training_args.medusa_num_layers,
base_model_name_or_path=model_args.model_name_or_path,
)
# Save Medusa config
medusa_config.save_pretrained(training_args.output_dir)
    # Start the trainer
trainer = CustomizedTrainer(
model=medusa_lm_head, tokenizer=tokenizer, args=training_args, **data_module
)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
model.config.use_cache = True
    # Save the Medusa heads separately (unwrap DDP if needed)
if hasattr(medusa_lm_head, "module"):
lm_head = medusa_lm_head.module.medusa_head
else:
lm_head = medusa_lm_head.medusa_head
# Save Medusa heads
torch.save(
lm_head.state_dict(),
os.path.join(training_args.output_dir, "medusa_lm_head.pt"),
)
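# Example invocation (illustrative values only; the dataset name must match one of the
# CalibrationDataset subclasses, and the remaining flags are standard TrainingArguments):
#
#   python medusa_training_script.py \
#       --model_name_or_path lmsys/vicuna-7b-v1.5 \
#       --dataset <calibration-dataset-name> \
#       --output_dir medusa_heads \
#       --medusa_num_heads 3 \
#       --medusa_num_layers 1 \
#       --per_device_train_batch_size 4 \
#       --learning_rate 1e-3 \
#       --num_train_epochs 1 \
#       --bf16 True \
#       --load_in_4bit True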
if __name__ == "__main__":
train()