""" | |
Hold the training script for the medusa model. | |
Adapted from the original code here: https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py | |
""" | |
import os
from dataclasses import dataclass, field
import pathlib
from typing import Dict, Optional

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset
import transformers
from transformers import Trainer, BitsAndBytesConfig
from transformers.trainer_pt_utils import LabelSmoother

from medusa.model.medusa_model import MedusaModel, MedusaConfig
from calibration_datasets import CalibrationDataset

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
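# LabelSmoother.ignore_index is -100, the label value that CrossEntropyLoss skips by
# default; the top-k accuracy stats in compute_loss filter on it as well.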


# Customized for training Medusa heads
class CustomizedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute the training loss for the model.

        Args:
            model (torch.nn.Module): The model for which to compute the loss.
            inputs (dict): The input data, including input IDs, attention mask, and labels.
            return_outputs (bool): Whether to return model outputs along with the loss.

        Returns:
            Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs.
        """
        # DDP will give us model.module
        if hasattr(model, "module"):
            medusa = model.module.medusa
        else:
            medusa = model.medusa
        logits = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
        labels = inputs["labels"]
        # Shift so that tokens < n predict n
        loss = 0
        loss_fct = CrossEntropyLoss()
        log = {}
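        # Medusa head i is trained to predict the token (2 + i) positions ahead of the
        # current one (the base LM head already covers the next token), so its logits at
        # position t are matched against the label at position t + 2 + i.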
        for i in range(medusa):
            medusa_logits = logits[i, :, : -(2 + i)].contiguous()
            medusa_labels = labels[..., 2 + i :].contiguous()
            medusa_logits = medusa_logits.view(-1, logits.shape[-1])
            medusa_labels = medusa_labels.view(-1)
            medusa_labels = medusa_labels.to(medusa_logits.device)
            loss_i = loss_fct(medusa_logits, medusa_labels)
            loss += loss_i
            not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
            medusa_labels = medusa_labels[not_ignore]
            # Add top-k accuracy
            for k in range(1, 6):
                _, topk = medusa_logits.topk(k, dim=-1)
                topk = topk[not_ignore]
                correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
                log[f"medusa{i}_top{k}"] = correct.float().mean().item()
            log[f"medusa{i}_loss"] = loss_i.item()
        self.log(log)
        return (loss, logits) if return_outputs else loss


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field()
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Load in 4 bit."},
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Load in 8 bit."},
    )


@dataclass
class DataArguments:
    dataset: str = field(
        metadata={"help": "One of the dataset names in a CalibrationDataset subclass."},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    medusa_num_heads: int = field(
        default=1,
        metadata={"help": "Number of Medusa heads."},
    )
    medusa_num_layers: int = field(
        default=1,
        metadata={"help": "Number of layers for each Medusa head."},
    )


local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """
    Save the model's state dictionary to a specified directory.

    Args:
        trainer (transformers.Trainer): The Hugging Face Trainer object.
        output_dir (str): The directory where the model state dictionary will be saved.
    """
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Args:
        dataset (str): One of the dataset names in a CalibrationDataset subclass.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
    """

    def __init__(self, dataset, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        rank0_print("Formatting inputs...")
        # Resolve the dataset name to the matching CalibrationDataset subclass
        dataset_classes = CalibrationDataset.__subclasses__()
        for dataset_class in dataset_classes:
            if dataset_class.dataset == dataset:
                dataset = dataset_class(num_samples=int(1e6), seqlen=tokenizer.model_max_length, tokenizer=tokenizer)
                break
        else:
            raise ValueError(f"Unknown dataset: {dataset}")
        tokenized = dataset.tokenize_dataset()
        self.input_ids = torch.tensor([data["input_ids"] for data in tokenized], dtype=torch.long)
        self.attention_mask = torch.tensor([data["attention_mask"] for data in tokenized], dtype=torch.long)

    def __len__(self):
        return self.input_ids.shape[0]

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Labels are the inputs themselves; compute_loss applies the per-head shifts
        return dict(
            input_ids=self.input_ids[i],
            labels=self.input_ids[i],
            attention_mask=self.attention_mask[i],
        )


def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank

    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    config.use_cache = False
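
    # BitsAndBytes NF4 4-bit quantization with nested (double) quantization and
    # bfloat16 compute; only passed to from_pretrained below when --load_in_4bit is set.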
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    # Load model and tokenizer
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config if model_args.load_in_4bit else None,
        load_in_4bit=model_args.load_in_4bit,
        load_in_8bit=model_args.load_in_8bit,
    )

    # Freeze the base model
    for param in model.base_model.parameters():
        param.requires_grad = False
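    # With the base model frozen, the trainable parameters are essentially just the
    # Medusa heads added below, which keeps the memory cost of training small.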

    # Add Medusa heads
    medusa_lm_head = MedusaModel(
        model,
        medusa_num_heads=training_args.medusa_num_heads,
        medusa_num_layers=training_args.medusa_num_layers,
        base_model_name_or_path=model_args.model_name_or_path,
    )

    # Format output dir
    training_args.output_dir = f"{training_args.output_dir}_medusa_{model_args.model_name_or_path.split('/')[-1]}"

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
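    # Llama-style tokenizers do not define a pad token, so reuse the unk token for
    # right-padding (assumes the tokenizer has an unk_token set).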
    tokenizer.pad_token = tokenizer.unk_token

    # Load data
    data_module = {"train_dataset": SupervisedDataset(data_args.dataset, tokenizer), "eval_dataset": None}

    # Generate Medusa config for pushing to HF hub
    medusa_config = MedusaConfig(
        medusa_num_heads=training_args.medusa_num_heads,
        medusa_num_layers=training_args.medusa_num_layers,
        base_model_name_or_path=model_args.model_name_or_path,
    )

    # Save Medusa config
    medusa_config.save_pretrained(training_args.output_dir)

    # Start trainer
    trainer = CustomizedTrainer(
        model=medusa_lm_head, tokenizer=tokenizer, args=training_args, **data_module
    )

    # Resume from the last checkpoint if one exists in the output directory
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    model.config.use_cache = True

    # Save MedusaHead separately
    if hasattr(medusa_lm_head, "module"):
        lm_head = medusa_lm_head.module.medusa_head
    else:
        lm_head = medusa_lm_head.medusa_head

    # Save Medusa heads
    torch.save(
        lm_head.state_dict(),
        os.path.join(training_args.output_dir, "medusa_lm_head.pt"),
    )


if __name__ == "__main__":
    train()
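
# Example launch (a sketch: the script name, model, dataset, and GPU count below are
# placeholders; pick a dataset name defined by a CalibrationDataset subclass and adjust
# the hyperparameters to your setup):
#
#   torchrun --nproc_per_node=2 train_medusa.py \
#       --model_name_or_path mistralai/Mistral-7B-v0.1 \
#       --dataset wikitext \
#       --load_in_4bit True \
#       --output_dir medusa_heads \
#       --num_train_epochs 1 \
#       --per_device_train_batch_size 4 \
#       --bf16 True \
#       --medusa_num_heads 4 \
#       --medusa_num_layers 1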