File size: 30,759 Bytes
5062eca
 
7657632
 
e303d64
0d6708b
4f4d638
 
7657632
0d6708b
7657632
b8e5603
7657632
 
 
 
5b67ea9
7657632
1210dc8
7657632
2bc1a5b
5b67ea9
 
2bc1a5b
 
37293dc
 
2bc1a5b
1a82082
0d6708b
e303d64
7657632
 
e30f1e3
7657632
 
09f1543
7657632
 
 
 
 
6c81c61
e303d64
 
7657632
e303d64
2bc1a5b
2844eb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1210dc8
 
 
1a82082
1210dc8
1a82082
1210dc8
 
 
 
 
 
1a82082
 
 
 
 
 
 
 
 
 
 
 
 
1210dc8
1a82082
 
ab5cd28
1210dc8
1a82082
 
 
1210dc8
e303d64
 
7b55fe6
e303d64
 
7b55fe6
e303d64
 
 
 
 
 
 
 
 
 
 
 
7b55fe6
e303d64
 
 
7657632
 
58ec8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7657632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e30f1e3
09f1543
7657632
 
09f1543
 
 
 
7657632
42f9642
7657632
 
 
 
 
 
 
 
 
 
 
42f9642
 
7657632
 
 
 
 
 
 
42f9642
 
7657632
 
42f9642
7657632
 
 
42f9642
7657632
42f9642
 
 
 
7657632
 
e30f1e3
 
 
 
7657632
5b67ea9
 
5a5d474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b67ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a5d474
5b67ea9
 
 
 
 
 
 
 
 
 
 
bf08044
5b67ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d870b
5b67ea9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d870b
5b67ea9
 
 
 
 
 
 
490923f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f4d638
 
 
 
 
 
 
 
 
490923f
 
 
b8e5603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
"""Callbacks for Trainer class"""

from __future__ import annotations

import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, List

import evaluate
import mlflow
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import wandb
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
from transformers import (
    GenerationConfig,
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
    barrier,
    broadcast_dict,
    gather_scalar_from_all_ranks,
    get_world_size,
    is_distributed,
    is_main_process,
    zero_first,
)

if TYPE_CHECKING:
    from axolotl.core.trainer_builder import AxolotlTrainingArguments

LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100


class EvalFirstStepCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """
    Callback that requests an evaluation as soon as the first global step ends
    """

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Set ``control.should_evaluate`` after step 1 when evaluating by steps."""
        evaluates_by_steps = args.evaluation_strategy == IntervalStrategy.STEPS
        if evaluates_by_steps and state.global_step == 1:
            control.should_evaluate = True
        return control


class SaveBetterTransformerModelCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods
    """Callback that writes checkpoints for a BetterTransformer wrapped model"""

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Save the reversed (unwrapped) model into the step's checkpoint folder.

        A save is triggered either because ``control.should_save`` is already
        set or because the current step lands on the ``save_steps`` interval.
        """
        on_save_interval = (
            args.save_strategy == IntervalStrategy.STEPS
            and args.save_steps > 0
            and state.global_step % args.save_steps == 0
        )
        if on_save_interval:
            control.should_save = True

        if not control.should_save:
            return control

        checkpoint_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_folder = os.path.join(args.output_dir, checkpoint_name)

        # FIXME - need to cleanup old checkpoints
        unwrapped_model = BetterTransformer.reverse(kwargs["model"])
        unwrapped_model.save_pretrained(checkpoint_folder)

        # The save already happened here, so clear the flag: the trainer loop
        # would otherwise try to save as well and raise, since it can't save a
        # BetterTransformer wrapped model.
        control.should_save = False
        return control


class GPUStatsCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """Callback that logs GPU memory usage once, early in training"""

    def __init__(self, cfg):
        self.cfg = cfg
        # flipped to True after the single memory report has been emitted
        self.logged = False

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Report memory usage the first time a step with ``global_step > 1`` ends."""
        if self.logged or state.global_step <= 1:
            return control

        log_gpu_memory_usage(LOG, "while training", self.cfg.device)
        self.logged = True
        return control


class LossWatchDogCallback(TrainerCallback):
    """Callback that halts training after repeated losses above a threshold"""

    def __init__(self, cfg):
        self.cfg = cfg
        self.logged = False
        # consecutive checks for which the latest logged loss exceeded the threshold
        self.violations = 0
        self.threshold = cfg.loss_watchdog_threshold
        # falls back to 3 when the configured patience is unset (or otherwise falsy)
        self.patience = cfg.loss_watchdog_patience or 3

    def on_step_end(
        self,
        _args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ):
        """
        Inspect the most recent log entry and stop training once the loss has
        been above the threshold for ``patience`` consecutive checks.
        """
        history = state.log_history
        if not history or "loss" not in history[-1]:
            return control

        latest_entry = history[-1]
        if latest_entry["loss"] > self.threshold:
            self.violations += 1
            if self.violations >= self.patience:
                LOG.warning(
                    "Loss is too high, stopping training (loss_watchdog_threshold)"
                )
                control.should_training_stop = True
        else:
            # any loss at or below the threshold resets the streak
            self.violations = 0
        return control


def bench_eval_callback_factory(trainer, tokenizer):
    """
    Build a TrainerCallback class that runs a multiple-choice benchmark on each evaluation.

    The benchmark dataset is chosen by ``trainer.args.bench_dataset``, loaded and
    tokenized once inside this factory, and captured by the returned class.

    Args:
        trainer: trainer whose ``args`` (``bench_dataset``, ``bench_split``,
            ``max_bench_samples``), ``model``, ``get_bench_dataloader``,
            ``prediction_step`` and ``log`` are used.
        tokenizer: tokenizer used to encode the benchmark inputs and outputs.

    Returns:
        The ``BenchEvalCallback`` class (not an instance).

    Raises:
        ValueError: if ``trainer.args.bench_dataset`` is not one of the handled values.
    """
    accuracy = evaluate.load("accuracy")
    # First token id of each candidate answer letter. Used below both to filter
    # the dataset and to restrict the logits when picking a prediction.
    abcd_idx = [
        tokenizer("A", add_special_tokens=False).input_ids[0],
        tokenizer("B", add_special_tokens=False).input_ids[0],
        tokenizer("C", add_special_tokens=False).input_ids[0],
        tokenizer("D", add_special_tokens=False).input_ids[0],
        tokenizer("E", add_special_tokens=False).input_ids[0],
        tokenizer("F", add_special_tokens=False).input_ids[0],
        tokenizer("G", add_special_tokens=False).input_ids[0],
    ]
    # Prefix for the logged metric names. The data split that is actually
    # evaluated is selected separately via trainer.args.bench_split below.
    bench_split = "eval"

    def transform_bench_subject(example):
        """Split ``example["subject"]`` on ':' into normalized ``name`` and ``subject`` fields."""
        # Split on ':' and trim whitespace
        parts = example["subject"].split(":")
        first_part = (
            parts[0].strip().lower().replace("-", "_")
        )  # Lowercase the first part
        second_part = (
            parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
        )  # Replace hyphens with underscores

        # Return the transformed values
        return {"name": first_part, "subject": second_part}

    if trainer.args.bench_dataset == "mmlu-zs":
        bench_dataset = load_dataset(
            "openaccess-ai-collective/mmlu-evals",
            data_files={
                "eval": "zero_shot_mmlu_val.json",
                "test": "zero_shot_mmlu_test.json",
            },
        )
        # bench_dataset = bench_dataset.remove_columns("subject")
    # MMLU Five-shot (Eval/Test only)
    elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
        bench_dataset = load_dataset(
            "openaccess-ai-collective/mmlu-evals",
            data_files={
                "eval": "five_shot_mmlu_val.json",
                "test": "five_shot_mmlu_test.json",
            },
        )
        # bench_dataset = bench_dataset.remove_columns('subject')
    elif "/" in trainer.args.bench_dataset:
        # Custom dataset given as "<owner>/<repo>/<path to data file>": the first
        # two path components name the dataset, the remainder is the data file.
        bench_ds = trainer.args.bench_dataset
        bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
        bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
        bench_dataset = load_dataset(
            bench_ds_name,
            data_files={
                "eval": bench_ds_data_file,
            },
        )
        bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
    else:
        raise ValueError(
            f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
        )
    bench_dataset = bench_dataset[trainer.args.bench_split]
    if trainer.args.max_bench_samples is not None:
        # Keep only the first max_bench_samples rows.
        bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))

    def tokenize_evals(example):
        """Tokenize one example into ``input_ids``/``labels``, masking the input part of the labels."""
        source = f"{tokenizer.bos_token}{example['input']}"
        target = f"{example['output']}{tokenizer.eos_token}"

        tokenized_source = tokenizer(
            source,
            max_length=2048,
            truncation=True,
            add_special_tokens=False,
        )
        tokenized_target = tokenizer(
            target,
            max_length=2048,
            truncation=True,
            add_special_tokens=False,
        )
        input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
        # Source positions get IGNORE_INDEX so only the target tokens carry labels.
        labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
            "input_ids"
        ]

        return {
            "input_ids": input_ids,
            "labels": labels,
            "subject": example["subject"],
        }

    with zero_first(is_main_process()):
        bench_dataset = bench_dataset.map(tokenize_evals)
        # Keep only examples whose second-to-last label token is one of the
        # candidate letters (presumably the answer token just before EOS).
        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)

    class BenchEvalCallback(TrainerCallback):
        """
        TrainerCallback that runs the MMLU evals
        """

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,
            state: TrainerState,  # pylint: disable=unused-argument
            control: TrainerControl,  # pylint: disable=unused-argument
            metrics: Dict[str, float],  # pylint: disable=unused-argument
            **kwargs,  # pylint: disable=unused-argument
        ):
            """
            Run the benchmark, log loss/accuracy via ``trainer.log`` on the main
            process, and copy the resulting values into ``metrics``.
            """
            data_loader = trainer.get_bench_dataloader(
                bench_dataset.remove_columns(["input", "subject", "output", "name"])
            )
            trainer.model.eval()
            preds, refs = [], []
            loss_bench = 0
            for batch in tqdm(data_loader, total=len(data_loader)):
                (loss, logits, labels) = trainer.prediction_step(
                    trainer.model,
                    batch,
                    prediction_loss_only=False,
                )
                # There are two tokens, the output, and eos token.
                for i, logit in enumerate(logits):
                    # Index of the first label position that is not masked out.
                    label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
                        0
                    ][0]
                    # Logits one position earlier, restricted to the candidate
                    # letters; the argmax is an index into abcd_idx.
                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
                    preds.append(torch.argmax(logit_abcd).item())
                # Unmasked labels come in pairs; keep the first token of each pair.
                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
                # References are indices into abcd_idx, or -1 for an unknown token.
                refs += [
                    abcd_idx.index(label) if label in abcd_idx else -1
                    for label in labels.tolist()
                ]
                loss_bench += loss.item()
            # Extract results by subject.
            bench_name = bench_dataset["name"]
            bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
            for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
                bench_names[s]["preds"].append(p)
                bench_names[s]["refs"].append(r)
            barrier()
            local_bench_names = bench_names
            gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
            # Gather results from all GPUs to GPU 0

            loss_bench_ranks = gather_scalar_from_all_ranks(
                lambda: loss_bench, get_world_size()
            )
            len_data_loader_ranks = gather_scalar_from_all_ranks(
                lambda: len(data_loader), get_world_size()
            )

            results = {}
            if is_distributed() and not is_main_process():
                # Non-main ranks only contribute their local results.
                dist.gather_object(local_bench_names, dst=0)
            else:
                if is_distributed():
                    dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
                else:
                    gathered_bench_names = [local_bench_names]
                # Summed per-batch losses divided by the total number of batches.
                bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
                results = {f"{bench_split}_bench_loss": bench_loss}

                # Combine results from all GPUs
                combined_bench_names: Dict[str, Dict[str, List]] = {}
                for bench_name in gathered_bench_names:
                    for name, data in bench_name.items():
                        if name not in combined_bench_names:
                            combined_bench_names[name] = {"refs": [], "preds": []}
                        combined_bench_names[name]["refs"].extend(data["refs"])
                        combined_bench_names[name]["preds"].extend(data["preds"])

                bench_scores = []
                bench_refs = []
                bench_preds = []
                for (
                    bench_name
                ) in combined_bench_names:  # pylint: disable=consider-using-dict-items
                    bench_score = accuracy.compute(
                        references=combined_bench_names[bench_name]["refs"],
                        predictions=combined_bench_names[bench_name]["preds"],
                    )["accuracy"]
                    bench_refs.extend(combined_bench_names[bench_name]["refs"])
                    bench_preds.extend(combined_bench_names[bench_name]["preds"])
                    # A NaN accuracy is recorded as 0.0 so it still counts in the average.
                    if not pd.isna(bench_score):
                        results[
                            f"{bench_split}_bench_accuracy_{bench_name}"
                        ] = bench_score
                        bench_scores.append(bench_score)
                    else:
                        results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0
                        bench_scores.append(0.0)
                # Average of the per-name accuracies vs. accuracy over all examples pooled.
                results[f"{bench_split}_bench_average_accuracy"] = np.mean(bench_scores)
                results[f"{bench_split}_bench_total_accuracy"] = accuracy.compute(
                    references=bench_refs, predictions=bench_preds
                )["accuracy"]
                trainer.log(results)

            # NOTE(review): presumably shares the main process results with every
            # rank so all of them update `metrics` - confirm against broadcast_dict.
            results = broadcast_dict(results)
            for key, val in results.items():
                metrics[key] = val

    return BenchEvalCallback


def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
    """
    Build a TrainerCallback class that generates completions for the eval set
    and scores them with the metrics named in ``cfg.eval_causal_lm_metrics``.

    Args:
        trainer: trainer whose ``model`` is used for generation and whose
            ``log`` receives the computed scores.
        tokenizer: tokenizer used to decode prompts/completions and to
            re-encode the prompts for generation.

    Returns:
        The ``CausalLMBenchEvalCallback`` class (not an instance); it is
        constructed with the axolotl ``cfg``.
    """

    class CausalLMBenchEvalCallback(TrainerCallback):
        """Callback to log prediction values during each evaluation"""

        def __init__(self, cfg):
            self.cfg = cfg
            self.logged = False
            self.metrics = self.__maybe_load_metrics()

        def __maybe_load_metrics(self):
            """Load every configured metric, skipping (with a warning) those that fail to load."""
            metrics = {}
            for metric in self.cfg.eval_causal_lm_metrics:
                try:
                    metrics[metric] = evaluate.load(metric)
                except Exception as exc:  # pylint: disable=broad-exception-caught
                    LOG.warning(f"{metric}: {exc.args}")
            return metrics

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
            state: TrainerState,
            control: TrainerControl,
            train_dataloader,  # pylint: disable=unused-argument
            eval_dataloader,
            **kwargs,  # pylint: disable=unused-argument
        ):
            """
            Generate predictions for ``eval_dataloader`` on the main process,
            score them with the loaded metrics and log the scores.

            Returns:
                The unmodified ``control`` object.
            """
            trainer.model.eval()
            device = torch.device(self.cfg.device)

            # pylint: disable=duplicate-code
            generation_config = GenerationConfig(
                max_new_tokens=self.cfg.eval_max_new_tokens,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=False,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )

            def find_ranges(lst):
                """
                Split position ids into inclusive ``(start, end)`` index ranges,
                starting a new range at every index > 0 whose value is 0.
                """
                ranges = []
                start = 0
                for i in range(1, len(lst)):
                    if lst[i] == 0:
                        ranges.append((start, i - 1))
                        start = i
                end = len(lst) - 1
                ranges.append((start, end))
                return ranges

            def compute(metric: evaluate.Metric, **kwargs):
                """
                Compute ``metric`` and return its ``score`` (or ``mean_score``).

                Returns ``None`` when the computation itself fails, and the raw
                result when it succeeds but holds neither expected key.
                """
                # safely compute a metric and return the score if the format is correct
                metric_score = None
                try:
                    metric_score = metric.compute(**kwargs)
                    return (
                        metric_score["score"]
                        if "score" in metric_score
                        else metric_score["mean_score"]
                    )
                except Exception:  # pylint: disable=broad-exception-caught
                    LOG.debug(
                        f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
                    )
                return metric_score

            def evaluate_preds(sources, predictions, references):
                """Score the predictions with every loaded metric; returns ``{metric_name: score}``."""
                scores = {}

                for metric_name, metric in self.metrics.items():
                    score = compute(
                        metric,
                        references=references,
                        predictions=predictions,
                        sources=sources,
                    )
                    # Retry with one-reference-per-prediction lists only when the
                    # first attempt produced nothing. Testing `is None` (not
                    # truthiness) keeps a legitimate score of 0.0.
                    if score is None:
                        score = compute(
                            metric,
                            references=[[r] for r in references],
                            predictions=predictions,
                        )
                    scores[metric_name] = score
                return scores

            def predict_with_generate():
                """Return ``(prompt_texts, predicted_texts, reference_texts)`` for the eval set."""
                eval_src, eval_pred, eval_ref = [], [], []

                for batch in tqdm(eval_dataloader):
                    batch_labels = batch["labels"].to(device)
                    batch_input_ids = batch["input_ids"].to(device)

                    if "position_ids" in batch:
                        batch_pos_ids = batch["position_ids"].tolist()
                    else:
                        batch_pos_ids = [None] * len(batch["input_ids"])

                    prompt_token_ids_list = []
                    completion_token_ids_list = []

                    for input_ids_all, labels_all, pos_ids in zip(
                        batch_input_ids,
                        batch_labels,
                        batch_pos_ids,
                    ):
                        if pos_ids is None:
                            pos_ranges = [(0, len(input_ids_all) - 1)]
                        else:
                            pos_ranges = find_ranges(pos_ids)

                        for pos_range in pos_ranges:
                            start, end = pos_range
                            # single-token ranges carry nothing to evaluate
                            if start == end:
                                continue

                            input_ids = input_ids_all[start : end + 1]
                            labels = labels_all[start : end + 1]

                            # Prompt = tokens without a label that are not padding;
                            # completion = tokens that do carry a label.
                            tokens_without_loss = labels == IGNORE_INDEX
                            tokens_with_loss = labels != IGNORE_INDEX
                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
                            prompt_token_includes = (
                                tokens_without_loss & tokens_exclude_padding
                            )

                            prompt_token_ids = input_ids[prompt_token_includes]
                            prompt_token_ids_list.append(prompt_token_ids)

                            completion_token_ids = input_ids[tokens_with_loss]
                            completion_token_ids_list.append(completion_token_ids)

                    prompt_texts = tokenizer.batch_decode(
                        prompt_token_ids_list, skip_special_tokens=True
                    )
                    completion_texts = tokenizer.batch_decode(
                        completion_token_ids_list, skip_special_tokens=True
                    )

                    with torch.no_grad():
                        prompt_encoding = tokenizer(
                            prompt_texts, padding=True, return_tensors="pt"
                        ).to(self.cfg.device)
                        predictions = trainer.model.generate(
                            **prompt_encoding, generation_config=generation_config
                        )

                    # The generated sequences start with the re-encoded prompt
                    # batch, which is padded to a common width. Strip that width
                    # rather than the length of the original, unpadded prompt,
                    # otherwise shorter prompts leak into the predicted text.
                    padded_prompt_len = prompt_encoding["input_ids"].shape[1]
                    prediction_all_tokens = predictions["sequences"].cpu().tolist()
                    prediction_without_prompt_tokens_list = [
                        prediction_tokens[padded_prompt_len:]
                        for prediction_tokens in prediction_all_tokens
                    ]

                    predicted_texts = tokenizer.batch_decode(
                        prediction_without_prompt_tokens_list, skip_special_tokens=True
                    )

                    eval_src.extend(prompt_texts)
                    eval_pred.extend(predicted_texts)
                    eval_ref.extend(completion_texts)

                return eval_src, eval_pred, eval_ref

            if is_main_process():
                eval_preds = predict_with_generate()
                trainer.log(evaluate_preds(*eval_preds))

            return control

    return CausalLMBenchEvalCallback


def log_prediction_callback_factory(trainer: Trainer, tokenizer):
    """
    Build a TrainerCallback class that logs a wandb table comparing reference
    completions with model predictions on each evaluation.

    Args:
        trainer: trainer whose ``model`` and ``prediction_step`` produce the
            predictions.
        tokenizer: tokenizer used to decode token ids and to re-encode the
            prompts for generation.

    Returns:
        The ``LogPredictionCallback`` class (not an instance); it is
        constructed with the axolotl ``cfg``.
    """

    class LogPredictionCallback(TrainerCallback):
        """Callback to log prediction values during each evaluation"""

        def __init__(self, cfg):
            self.cfg = cfg
            self.logged = False

        def on_evaluate(
            self,
            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
            state: TrainerState,
            control: TrainerControl,
            train_dataloader,  # pylint: disable=unused-argument
            eval_dataloader,
            **kwargs,  # pylint: disable=unused-argument
        ):
            """
            On the main process, log a predictions-vs-ground-truth table built
            from ``eval_dataloader``. Does nothing when ``cfg.eval_table_size``
            is not positive.

            Returns:
                The unmodified ``control`` object.
            """
            eval_table_size = self.cfg.eval_table_size

            if eval_table_size <= 0:
                return control

            trainer.model.eval()
            device = torch.device(self.cfg.device)

            # pylint: disable=duplicate-code
            generation_config = GenerationConfig(
                max_new_tokens=self.cfg.eval_max_new_tokens,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=False,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )

            def logits_to_tokens(logits) -> torch.Tensor:
                """Return the highest-probability token id along the last dimension."""
                probabilities = torch.softmax(logits, dim=-1)
                # Get the predicted token ids (the ones with the highest probability)
                predicted_token_ids = torch.argmax(probabilities, dim=-1)
                return predicted_token_ids

            def find_ranges(lst):
                """
                Split position ids into inclusive ``(start, end)`` index ranges,
                starting a new range at every index > 0 whose value is 0.
                """
                ranges = []
                start = 0
                for i in range(1, len(lst)):
                    if lst[i] == 0:
                        ranges.append((start, i - 1))
                        start = i
                end = len(lst) - 1
                ranges.append((start, end))
                return ranges

            def log_table_from_dataloader(name: str, table_dataloader):
                """Fill a wandb table from ``table_dataloader`` and log it under ``name``."""
                table = wandb.Table(  # type: ignore[attr-defined]
                    columns=[
                        "id",
                        "Prompt",
                        "Correct Completion",
                        "Predicted Completion (model.generate)",
                        "Predicted Completion (trainer.prediction_step)",
                    ]
                )
                row_index = 0

                for batch in tqdm(table_dataloader):
                    # Stop as soon as the table holds eval_table_size rows. The
                    # check runs once per batch, so the final batch may still
                    # push the row count past the target.
                    if row_index >= eval_table_size:
                        break

                    batch_labels = batch["labels"].to(device)
                    batch_input_ids = batch["input_ids"].to(device)

                    if "position_ids" in batch:
                        batch_pos_ids = batch["position_ids"].tolist()
                    else:
                        batch_pos_ids = [None] * len(batch["input_ids"])

                    (_, batch_logits, _) = trainer.prediction_step(
                        trainer.model,
                        batch,
                        prediction_loss_only=False,
                    )

                    prompt_token_ids_list = []
                    pred_step_token_ids_list = []
                    completion_token_ids_list = []

                    for input_ids_all, labels_all, pos_ids, logits in zip(
                        batch_input_ids,
                        batch_labels,
                        batch_pos_ids,
                        batch_logits,
                    ):
                        if pos_ids is None:
                            pos_ranges = [(0, len(input_ids_all) - 1)]
                        else:
                            pos_ranges = find_ranges(pos_ids)

                        for pos_range in pos_ranges:
                            start, end = pos_range
                            # single-token ranges carry nothing to display
                            if start == end:
                                continue

                            input_ids = input_ids_all[start : end + 1]
                            labels = labels_all[start : end + 1]

                            # Prompt = tokens without a label that are not padding;
                            # completion = tokens that do carry a label.
                            tokens_without_loss = labels == IGNORE_INDEX
                            tokens_with_loss = labels != IGNORE_INDEX
                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
                            prompt_token_includes = (
                                tokens_without_loss & tokens_exclude_padding
                            )

                            prompt_token_ids = input_ids[prompt_token_includes]
                            prompt_token_ids_list.append(prompt_token_ids)

                            completion_token_ids = input_ids[tokens_with_loss]
                            completion_token_ids_list.append(completion_token_ids)

                            pred_step_token_ids = logits_to_tokens(
                                logits[start : end + 1]
                            )[tokens_with_loss]
                            pred_step_token_ids_list.append(pred_step_token_ids)

                    prompt_texts = tokenizer.batch_decode(
                        prompt_token_ids_list, skip_special_tokens=True
                    )
                    completion_texts = tokenizer.batch_decode(
                        completion_token_ids_list, skip_special_tokens=True
                    )
                    pred_step_texts = tokenizer.batch_decode(
                        pred_step_token_ids_list, skip_special_tokens=True
                    )

                    with torch.no_grad():
                        prompt_encoding = tokenizer(
                            prompt_texts, padding=True, return_tensors="pt"
                        ).to(self.cfg.device)
                        predictions = trainer.model.generate(
                            **prompt_encoding, generation_config=generation_config
                        )

                    # The generated sequences start with the re-encoded prompt
                    # batch, which is padded to a common width. Strip that width
                    # rather than the length of the original, unpadded prompt,
                    # otherwise shorter prompts leak into the predicted text.
                    padded_prompt_len = prompt_encoding["input_ids"].shape[1]
                    prediction_all_tokens = predictions["sequences"].cpu().tolist()
                    prediction_without_prompt_tokens_list = [
                        prediction_tokens[padded_prompt_len:]
                        for prediction_tokens in prediction_all_tokens
                    ]

                    predicted_texts = tokenizer.batch_decode(
                        prediction_without_prompt_tokens_list, skip_special_tokens=True
                    )

                    for (
                        prompt_text,
                        completion_text,
                        prediction_text,
                        pred_step_text,
                    ) in zip(
                        prompt_texts, completion_texts, predicted_texts, pred_step_texts
                    ):
                        table.add_data(
                            row_index,
                            prompt_text,
                            completion_text,
                            prediction_text,
                            pred_step_text,
                        )
                        row_index += 1

                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})  # type: ignore[attr-defined]

            if is_main_process():
                log_table_from_dataloader("Eval", eval_dataloader)

            return control

    return LogPredictionCallback


class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
    """Callback that uploads the axolotl config file to the WandB run"""

    def __init__(self, axolotl_config_path):
        self.axolotl_config_path = axolotl_config_path

    def on_train_begin(
        self,
        args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Copy the config to a temp file and hand it to ``wandb.save`` (main process only)."""
        if not is_main_process():
            return control

        try:
            # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
            with NamedTemporaryFile(
                mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
            ) as staged_copy:
                copyfile(self.axolotl_config_path, staged_copy.name)
                wandb.save(staged_copy.name)
            LOG.info("The Axolotl config has been saved to the WandB run under files.")
        except (FileNotFoundError, ConnectionError) as exc:
            # best effort: a missing config or connection failure must not stop training
            LOG.warning(f"Error while saving Axolotl config to WandB: {exc}")
        return control


class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
    """Callback that stores the axolotl config file as an MLflow artifact"""

    def __init__(self, axolotl_config_path):
        self.axolotl_config_path = axolotl_config_path

    def on_train_begin(
        self,
        args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Copy the config to a temp file and log it with ``mlflow.log_artifact`` (main process only)."""
        if not is_main_process():
            return control

        try:
            with NamedTemporaryFile(
                mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
            ) as staged_copy:
                copyfile(self.axolotl_config_path, staged_copy.name)
                mlflow.log_artifact(staged_copy.name, artifact_path="")
                LOG.info("The Axolotl config has been saved to the MLflow artifacts.")
        except (FileNotFoundError, ConnectionError) as exc:
            # best effort: a missing config or connection failure must not stop training
            LOG.warning(f"Error while saving Axolotl config to MLflow: {exc}")
        return control