"""Callbacks for Trainer class""" | |
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Dict, List

import evaluate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
    barrier,
    gather_scalar_from_all_ranks,
    get_world_size,
    is_main_process,
    zero_first,
)

if TYPE_CHECKING:
    from axolotl.utils.trainer import AxolotlTrainingArguments

LOG = logging.getLogger("axolotl.callbacks")
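
# Label value ignored by the loss function (the default ignore_index in
# PyTorch/HF); used below to mask prompt tokens so only answer tokens are scored.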
IGNORE_INDEX = -100


class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
    """Callback to save the PEFT adapter"""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(
            args.output_dir,
            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
        )

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(
            peft_model_path, save_safetensors=args.save_safetensors
        )

        return control


class SaveBetterTransformerModelCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods
    """Callback to save the BetterTransformer wrapped model"""

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Decide whether this step should trigger a save, mirroring the
        # trainer's step-based save schedule.
        if (
            args.save_strategy == IntervalStrategy.STEPS
            and args.save_steps > 0
            and state.global_step % args.save_steps == 0
        ):
            control.should_save = True

        if control.should_save:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
            )

            model = BetterTransformer.reverse(kwargs["model"])
            model.save_pretrained(checkpoint_folder)
            # FIXME - need to cleanup old checkpoints

            # Since we save here, the trainer loop must not attempt to save as well,
            # because it would raise an exception trying to save a BetterTransformer
            # wrapped model.
            control.should_save = False
        return control


class GPUStatsCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods,unused-argument
    """Callback to track GPU utilization"""

    def __init__(self, cfg):
        self.cfg = cfg
        self.logged = False

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Log GPU memory usage once, shortly after training has started.
        if not self.logged and state.global_step > 1:
            log_gpu_memory_usage(LOG, "while training", self.cfg.device)
            self.logged = True
        return control


def bench_eval_callback_factory(trainer, tokenizer):
    accuracy = evaluate.load("accuracy")
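    # Token ids of the single-letter answer choices A-G; the model's prediction
    # is read off the logits restricted to these ids.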
    abcd_idx = [
        tokenizer("A", add_special_tokens=False).input_ids[0],
        tokenizer("B", add_special_tokens=False).input_ids[0],
        tokenizer("C", add_special_tokens=False).input_ids[0],
        tokenizer("D", add_special_tokens=False).input_ids[0],
        tokenizer("E", add_special_tokens=False).input_ids[0],
        tokenizer("F", add_special_tokens=False).input_ids[0],
        tokenizer("G", add_special_tokens=False).input_ids[0],
    ]
    bench_split = "eval"
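
    # Normalize a "category: subject" string into separate benchmark name and subject
    # fields, e.g. "mmlu: college-biology" -> {"name": "mmlu", "subject": "college_biology"}.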
    def transform_bench_subject(example):
        # Split on ':' and trim whitespace
        parts = example["subject"].split(":")
        first_part = (
            parts[0].strip().lower().replace("-", "_")
        )  # Lowercase and replace hyphens with underscores
        second_part = (
            parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
        )  # Replace hyphens with underscores; default to "all" if no subject is given

        # Return the transformed values
        return {"name": first_part, "subject": second_part}

    if trainer.args.bench_dataset == "mmlu-zs":
        bench_dataset = load_dataset(
            "openaccess-ai-collective/mmlu-evals",
            data_files={
                "eval": "zero_shot_mmlu_val.json",
                "test": "zero_shot_mmlu_test.json",
            },
        )
        # bench_dataset = bench_dataset.remove_columns("subject")
    # MMLU Five-shot (Eval/Test only)
    elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
        bench_dataset = load_dataset(
            "openaccess-ai-collective/mmlu-evals",
            data_files={
                "eval": "five_shot_mmlu_val.json",
                "test": "five_shot_mmlu_test.json",
            },
        )
        # bench_dataset = bench_dataset.remove_columns('subject')
    elif "/" in trainer.args.bench_dataset:
        bench_ds = trainer.args.bench_dataset
        bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
        bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
        bench_dataset = load_dataset(
            bench_ds_name,
            data_files={
                "eval": bench_ds_data_file,
            },
        )
        bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
    else:
        raise ValueError(
            f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
        )
    bench_dataset = bench_dataset[trainer.args.bench_split]
    if trainer.args.max_bench_samples is not None:
        bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
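
    # Tokenize the prompt and the answer separately; prompt positions are masked
    # with IGNORE_INDEX so only the answer (and trailing eos) tokens act as labels.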
    def tokenize_evals(example):
        source = f"{tokenizer.bos_token}{example['input']}"
        target = f"{example['output']}{tokenizer.eos_token}"

        tokenized_source = tokenizer(
            source,
            max_length=2048,
            truncation=True,
            add_special_tokens=False,
        )
        tokenized_target = tokenizer(
            target,
            max_length=2048,
            truncation=True,
            add_special_tokens=False,
        )
        input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
        labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
            "input_ids"
        ]

        return {
            "input_ids": input_ids,
            "labels": labels,
            "subject": example["subject"],
        }
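
    # Run the map/filter on the main process first so the other ranks can reuse
    # the cached dataset instead of re-tokenizing.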
    with zero_first(is_main_process()):
        bench_dataset = bench_dataset.map(tokenize_evals)
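        # Keep only examples whose answer token (the label just before eos) is
        # one of the A-G ids.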
        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)

    class BenchEvalCallback(TrainerCallback):
        """
        TrainerCallback that runs the MMLU evals
        """

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,
            state: TrainerState,  # pylint: disable=unused-argument
            control: TrainerControl,  # pylint: disable=unused-argument
            metrics: Dict[str, float],  # pylint: disable=unused-argument
            **kwargs,  # pylint: disable=unused-argument
        ):
            data_loader = trainer.get_bench_dataloader(
                bench_dataset.remove_columns(["input", "subject", "output", "name"])
            )
            trainer.model.eval()
            preds, refs = [], []
            loss_bench = 0
            for batch in tqdm(data_loader, total=len(data_loader)):
                (loss, logits, labels) = trainer.prediction_step(
                    trainer.model,
                    batch,
                    prediction_loss_only=False,
                )
                # Each example's labels contain exactly two non-ignored tokens:
                # the answer token followed by the eos token.
                for i, logit in enumerate(logits):
                    label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
                        0
                    ][0]
                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
                    preds.append(torch.argmax(logit_abcd).item())
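                # Drop the masked prompt positions; each example leaves an
                # (answer, eos) pair, so the first element of each pair is the
                # reference answer token id.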
                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
                refs += [
                    abcd_idx.index(label) if label in abcd_idx else -1
                    for label in labels.tolist()
                ]

                loss_bench += loss.item()

            # Extract results by subject.
            bench_name = bench_dataset["name"]
            bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
            for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
                bench_names[s]["preds"].append(p)
                bench_names[s]["refs"].append(r)
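
            # Synchronize all ranks before gathering the per-rank losses, batch
            # counts, and per-subject predictions onto rank 0.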
            barrier()
            local_bench_names = bench_names
            gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
            # Gather results from all GPUs to GPU 0
            loss_bench_ranks = gather_scalar_from_all_ranks(
                lambda: loss_bench, get_world_size()
            )
            len_data_loader_ranks = gather_scalar_from_all_ranks(
                lambda: len(data_loader), get_world_size()
            )
            if not is_main_process():
                dist.gather_object(local_bench_names, dst=0)
            else:
                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
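                # Average loss on rank 0: total summed loss across ranks divided
                # by the total number of batches across ranks.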
                bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
                results = {f"{bench_split}_bench_loss": bench_loss}

                # Combine results from all GPUs
                combined_bench_names: Dict[str, Dict[str, List]] = {}
                for bench_name in gathered_bench_names:
                    for name, data in bench_name.items():
                        if name not in combined_bench_names:
                            combined_bench_names[name] = {"refs": [], "preds": []}
                        combined_bench_names[name]["refs"].extend(data["refs"])
                        combined_bench_names[name]["preds"].extend(data["preds"])

                bench_scores = []
                bench_refs = []
                bench_preds = []
                for (
                    bench_name
                ) in combined_bench_names:  # pylint: disable=consider-using-dict-items
                    bench_score = accuracy.compute(
                        references=combined_bench_names[bench_name]["refs"],
                        predictions=combined_bench_names[bench_name]["preds"],
                    )["accuracy"]
                    bench_refs.extend(combined_bench_names[bench_name]["refs"])
                    bench_preds.extend(combined_bench_names[bench_name]["preds"])
                    if not pd.isna(bench_score):
                        results[
                            f"{bench_split}_bench_accuracy_{bench_name}"
                        ] = bench_score
                        bench_scores.append(bench_score)
                    else:
                        results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
                        bench_scores.append(0.0)
                results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores)
                results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute(
                    references=bench_refs, predictions=bench_preds
                )["accuracy"]
                trainer.log(results)

    return BenchEvalCallback
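

# Minimal usage sketch (illustrative only): the factory is expected to be called
# during trainer setup, after the trainer and tokenizer exist, and the returned
# callback class registered on the trainer. The `cfg.do_bench_eval` gate below is
# an assumed config flag used for illustration, not defined in this module.
#
#     if cfg.do_bench_eval:
#         trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))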