File size: 29,782 Bytes
4c0eddb
eea2731
2bb0b78
6045345
553c80f
6045345
1a6309c
6045345
488a67d
2bb0b78
 
 
 
 
 
 
8d43785
553c80f
2809f3f
6045345
1c412c7
2ce5c0d
4ea9a66
6045345
37293dc
6045345
37293dc
6045345
37293dc
6045345
1365073
6045345
 
 
 
cf68153
37293dc
1365073
1a6309c
37293dc
ce34d64
e50ab07
6045345
553c80f
d2e7f27
fc2d6be
81d3845
2e22404
 
 
553c80f
2e22404
6045345
553a86b
2e22404
 
0b4cf5b
 
 
 
 
 
 
2e22404
e50ab07
2e22404
cb9797e
e50ab07
cb9797e
 
2e22404
553c80f
 
2ce5c0d
 
 
 
 
553c80f
2e22404
553c80f
2e22404
553c80f
 
2e22404
 
 
 
 
 
e50ab07
2e22404
797f3dd
 
 
 
 
 
 
2e22404
 
641e6f7
2e22404
 
 
641e6f7
e50ab07
553a86b
6045345
ce34d64
 
1a6309c
2809f3f
6045345
0b4cf5b
6045345
4a17a4c
cb7cd34
 
48630f5
 
 
 
 
 
392dfd9
cb7cd34
 
0b4cf5b
 
6045345
 
 
 
 
 
1d5ab84
e50ab07
1c33eb8
1d5ab84
 
ce34d64
37293dc
69fac9a
ce34d64
607a4d3
a1f9850
1d5ab84
6045345
1d5ab84
 
1e56b88
 
 
 
 
553a86b
6045345
553a86b
6045345
553a86b
 
2ce5c0d
 
 
 
2cfe9e9
 
 
 
553a86b
2cfe9e9
 
6045345
5ac3392
 
 
 
 
 
 
 
 
cb7cd34
e50ab07
a4f1241
6045345
 
37293dc
e50ab07
 
37293dc
69fac9a
37293dc
6045345
992d57f
6045345
 
3cc67d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6045345
e50ab07
9bdd30c
 
575a082
 
 
 
 
 
 
 
9bdd30c
3cc67d2
 
9bdd30c
d2e7f27
e50ab07
 
9bdd30c
 
 
 
 
 
 
6045345
88089e8
e50ab07
 
88089e8
e50ab07
69fac9a
88089e8
3cc67d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6045345
e50ab07
8fe0e63
e50ab07
8fe0e63
e50ab07
8fe0e63
e50ab07
8fe0e63
e50ab07
8fe0e63
 
e50ab07
8fe0e63
 
 
 
 
 
 
 
88089e8
e50ab07
 
 
 
 
88089e8
8d43785
392dfd9
e8aacfb
e50ab07
2e56203
2cfe9e9
e50ab07
392dfd9
2e56203
e50ab07
 
 
d2e7f27
 
e50ab07
d2e7f27
 
 
 
2e56203
 
409ca0f
 
e50ab07
 
409ca0f
e50ab07
409ca0f
 
e50ab07
aac4b76
e50ab07
 
 
 
 
 
 
 
 
 
 
 
2bb0b78
 
fe28543
2bb0b78
 
 
2ce5c0d
 
 
1c412c7
b1f4f7a
6045345
1d5ab84
553a86b
1d5ab84
 
ce34d64
 
 
6045345
e50ab07
aa3c3f9
 
3cc67d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce34d64
37293dc
 
 
1a6309c
2ce5c0d
 
aa3c3f9
 
097d367
553a86b
2bc1a5b
 
 
37293dc
 
2bc1a5b
097d367
ab5cd28
2bb0b78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b4cf5b
 
2bb0b78
1a6309c
 
 
 
 
 
 
2bb0b78
ab5cd28
 
 
 
 
6045345
e50ab07
 
 
 
 
 
 
 
 
3db5f2f
 
 
 
 
e50ab07
 
 
 
 
 
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
 
cdc71f7
7570446
 
3db5f2f
cdc71f7
e50ab07
 
 
 
 
 
 
 
 
 
 
 
 
488a67d
 
2f586d1
 
 
eea2731
2f586d1
eea2731
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553a86b
eea2731
488a67d
 
553c80f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c6f928
2f586d1
 
 
2ce5c0d
2f586d1
1157950
 
 
2f586d1
eea2731
553c80f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d3845
553c80f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
"""Module containing data utilities"""
import functools
import hashlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Union

import torch
from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)
from huggingface_hub import hf_hub_download
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
    AlpacaMultipleChoicePromptTokenizingStrategy,
    AlpacaPromptTokenizingStrategy,
    AlpacaReflectionPTStrategy,
    GPTeacherPromptTokenizingStrategy,
    JeopardyPromptTokenizingStrategy,
    OpenAssistantPromptTokenizingStrategy,
    SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
    AlpacaPrompter,
    GPTeacherPrompter,
    JeopardyPrompter,
    MultipleChoiceConcisePrompter,
    MultipleChoiceExplainPrompter,
    Prompter,
    ReflectAlpacaPrompter,
    SummarizeTLDRPrompter,
    UnsupportedPrompter,
)
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
    process_pretraining_datasets_for_packing,
)

# Module-level logger; everything in this file logs under the "axolotl" name.
LOG = logging.getLogger("axolotl")


def md5(to_hash: str, encoding: str = "utf-8") -> str:
    """
    Return the hexadecimal MD5 digest of ``to_hash``.

    The string is encoded with ``encoding`` before hashing. The digest is
    requested with ``usedforsecurity=False``; when ``hashlib.md5`` rejects
    that keyword with a ``TypeError`` the plain constructor is used instead.
    """
    payload = to_hash.encode(encoding)
    try:
        hasher = hashlib.md5(payload, usedforsecurity=False)
    except TypeError:
        # this hashlib build does not know the `usedforsecurity` keyword
        hasher = hashlib.md5(payload)  # nosec
    return hasher.hexdigest()


def prepare_dataset(cfg, tokenizer):
    """
    Build the train/eval datasets and the step budget described by ``cfg``.

    If ``cfg.pretraining_dataset`` is set, the streaming pretraining dataset is
    loaded and returned together with ``cfg.max_steps``, no eval dataset and an
    empty prompter list. Otherwise the prepared datasets are loaded inside
    ``zero_first(is_main_process())`` and the total step count is computed,
    capped at ``cfg.max_steps`` when that is set.

    Returns:
        Tuple of ``(train_dataset, eval_dataset, total_num_steps, prompters)``.

    Raises:
        ValueError: when eval sample packing is active and the eval split
            yields zero steps.
    """
    if cfg.pretraining_dataset:
        pretraining_cfg = cfg.pretraining_dataset
        ds_path, ds_name = pretraining_cfg, None
        # a list of dicts carries the path/name in its first entry
        if isinstance(pretraining_cfg, list) and isinstance(pretraining_cfg[0], dict):
            first_entry = pretraining_cfg[0]
            ds_path = first_entry["path"]
            ds_name = first_entry["name"]

        streamed = load_pretraining_dataset(
            ds_path,
            tokenizer,
            cfg,
            name=ds_name,
            max_tokens=cfg.sequence_len,
            seed=cfg.seed or 42,
        )
        # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
        return streamed.with_format("torch"), None, cfg.max_steps, []

    with zero_first(is_main_process()):
        train_dataset, eval_dataset, prompters = load_prepare_datasets(
            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
        )

    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        if calculate_total_num_steps(cfg, eval_dataset, update=False) == 0:
            raise ValueError(
                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
            )

    if not cfg.max_steps:
        uncapped_steps = calculate_total_num_steps(cfg, train_dataset)
        return train_dataset, eval_dataset, uncapped_steps, prompters

    total_num_steps = min(calculate_total_num_steps(cfg, train_dataset), cfg.max_steps)
    LOG.info(f"Maximum number of steps set at {total_num_steps}")
    return train_dataset, eval_dataset, total_num_steps, prompters


def load_tokenized_prepared_datasets(
    tokenizer, cfg, default_dataset_prepared_path
) -> Tuple[DatasetDict, List[Prompter]]:
    """
    Load the merged, tokenized dataset described by `cfg.datasets`.

    Lookup order:
      1. if `cfg.push_dataset_to_hub` is set, try the hub repo
         `{cfg.push_dataset_to_hub}/{ds_hash}` (any failure is ignored);
      2. else, if `cfg.dataset_prepared_path` is set, the hashed directory is
         non-empty and `cfg.is_preprocess` is falsy, load it from disk;
      3. otherwise load every raw dataset, wrap/tokenize it, merge the results,
         run `process_datasets_for_packing` and save the outcome (local rank 0
         only), optionally pushing it to the hub.

    :param tokenizer: tokenizer handed to the dataset wrappers; its class name
        is part of the cache hash
    :param cfg: run configuration (datasets, sequence_len, seed, paths, ...)
    :param default_dataset_prepared_path: base directory used for the prepared
        dataset when `cfg.dataset_prepared_path` is not set
    :return: `(dataset, prompters)`; `prompters` is only populated when the raw
        datasets are processed (it stays empty for hub/disk cache hits)
    :raises ImportError: for s3:// or gs:// / gcs:// paths when the required
        filesystem packages are missing
    :raises ValueError: when a dataset cannot be loaded, `data_files` has an
        unsupported type, or no usable split is found
    """
    tokenizer_name = tokenizer.__class__.__name__
    # cache key: sequence length + sorted per-dataset descriptors + tokenizer class name
    ds_hash = str(
        md5(
            (
                str(cfg.sequence_len)
                + "@"
                + "|".join(
                    sorted(
                        [
                            f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
                            for d in cfg.datasets
                        ]
                    )
                )
                + "|"
                + tokenizer_name
            )
        )
    )
    # the prepared dataset lives in a sub-directory named after the hash
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path) / ds_hash
        if cfg.dataset_prepared_path
        else Path(default_dataset_prepared_path) / ds_hash
    )
    dataset = None
    prompters = []
    use_auth_token = cfg.hf_use_auth_token
    # best effort: a previously pushed copy on the hub wins over everything else
    try:
        if cfg.push_dataset_to_hub:
            dataset = load_dataset(
                f"{cfg.push_dataset_to_hub}/{ds_hash}",
                token=use_auth_token,
            )
            dataset = dataset["train"]
    except Exception:  # pylint: disable=broad-except # nosec
        pass

    if dataset:
        # already loaded from the hub above, nothing left to do
        ...
    elif (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.is_preprocess
    ):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
    else:
        LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
        LOG.info("Loading raw datasets...")
        if not cfg.is_preprocess:
            LOG.warning(
                "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
            )

        if cfg.seed:
            seed = cfg.seed
        else:
            LOG.info("No seed provided, using default seed of 42")
            seed = 42

        datasets = []

        def for_d_in_datasets(dataset_configs):
            # expand a config whose `name` is a list into one config per name
            for dataset in dataset_configs:
                if dataset.name and isinstance(dataset.name, list):
                    for name in dataset.name:
                        yield DictDefault({**dataset, "name": name})
                else:
                    yield dataset

        # pylint: disable=invalid-name
        for config_dataset in for_d_in_datasets(cfg.datasets):
            ds: Union[Dataset, DatasetDict] = None
            ds_from_hub = False
            # probe only: the streaming result is discarded, success just marks
            # the path as loadable through `load_dataset`
            try:
                load_dataset(
                    config_dataset.path,
                    name=config_dataset.name,
                    streaming=True,
                    token=use_auth_token,
                )
                ds_from_hub = True
            except (FileNotFoundError, ConnectionError):
                pass

            ds_from_cloud = False
            storage_options = {}
            remote_file_system = None
            if config_dataset.path.startswith("s3://"):
                try:
                    import aiobotocore.session  # type: ignore
                    import s3fs  # type: ignore
                except ImportError as exc:
                    raise ImportError(
                        "s3:// paths require aiobotocore and s3fs to be installed"
                    ) from exc

                # Takes credentials from ~/.aws/credentials for default profile
                s3_session = aiobotocore.session.AioSession(profile="default")
                storage_options = {"session": s3_session}
                remote_file_system = s3fs.S3FileSystem(**storage_options)
            elif config_dataset.path.startswith(
                "gs://"
            ) or config_dataset.path.startswith("gcs://"):
                try:
                    import gcsfs  # type: ignore
                except ImportError as exc:
                    raise ImportError(
                        "gs:// or gcs:// paths require gcsfs to be installed"
                    ) from exc

                # gcsfs will use default credentials from the environment else anon
                # https://gcsfs.readthedocs.io/en/latest/#credentials
                storage_options = {"token": None}
                remote_file_system = gcsfs.GCSFileSystem(**storage_options)
            # TODO: Figure out how to get auth creds passed
            # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
            #     try:
            #         import adlfs
            #     except ImportError as exc:
            #        raise ImportError(
            #            "adl:// or abfs:// paths require adlfs to be installed"
            #        ) from exc

            #     # Gen 1
            #     storage_options = {
            #         "tenant_id": TENANT_ID,
            #         "client_id": CLIENT_ID,
            #         "client_secret": CLIENT_SECRET,
            #     }
            #     # Gen 2
            #     storage_options = {
            #         "account_name": ACCOUNT_NAME,
            #         "account_key": ACCOUNT_KEY,
            #     }

            #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
            # only flag the cloud source when the remote path really exists
            try:
                if remote_file_system and remote_file_system.exists(
                    config_dataset.path
                ):
                    ds_from_cloud = True
            except (FileNotFoundError, ConnectionError):
                pass

            # prefer local dataset, even if hub exists
            local_path = Path(config_dataset.path)
            if local_path.exists():
                if local_path.is_dir():
                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                    ds = load_dataset(
                        config_dataset.path,
                        name=config_dataset.name,
                        data_files=config_dataset.data_files,
                        streaming=False,
                        split=None,
                    )
                elif local_path.is_file():
                    ds_type = get_ds_type(config_dataset)

                    ds = load_dataset(
                        ds_type,
                        name=config_dataset.name,
                        data_files=config_dataset.path,
                        streaming=False,
                        split=None,
                    )
                else:
                    raise ValueError(
                        "unhandled dataset load: local path exists, but is neither a directory or a file"
                    )
            elif ds_from_hub:
                ds = load_dataset(
                    config_dataset.path,
                    name=config_dataset.name,
                    streaming=False,
                    data_files=config_dataset.data_files,
                    token=use_auth_token,
                )
            elif ds_from_cloud and remote_file_system:
                # remote directory -> saved dataset; remote file -> typed loader
                if remote_file_system.isdir(config_dataset.path):
                    ds = load_from_disk(
                        config_dataset.path,
                        storage_options=storage_options,
                    )
                elif remote_file_system.isfile(config_dataset.path):
                    ds_type = get_ds_type(config_dataset)
                    ds = load_dataset(
                        ds_type,
                        name=config_dataset.name,
                        data_files=config_dataset.path,
                        streaming=False,
                        split=None,
                        storage_options=storage_options,
                    )
            else:
                # last resort: download the named file(s) from the hub dataset
                # repo and load them with the "json" loader
                if isinstance(config_dataset.data_files, str):
                    fp = hf_hub_download(
                        repo_id=config_dataset.path,
                        repo_type="dataset",
                        filename=config_dataset.data_files,
                    )
                elif isinstance(config_dataset.data_files, list):
                    fp = []
                    for file in config_dataset.data_files:
                        fp.append(
                            hf_hub_download(
                                repo_id=config_dataset.path,
                                repo_type="dataset",
                                filename=file,
                            )
                        )
                else:
                    raise ValueError(
                        "data_files must be either a string or list of strings"
                    )
                ds = load_dataset(
                    "json",
                    name=config_dataset.name,
                    data_files=fp,
                    streaming=False,
                    split=None,
                )
            if not ds:
                raise ValueError("unhandled dataset load")
            # support for using a subset of the data
            if config_dataset.shards:
                # shuffle with the run seed, then keep only shard 0
                if "train" in ds:
                    ds = ds.shuffle(seed=seed)["train"].shard(
                        num_shards=config_dataset.shards, index=0
                    )
                else:
                    ds = ds.shuffle(seed=seed).shard(
                        num_shards=config_dataset.shards, index=0
                    )

            # a string type of the form "base:style" is split into its two parts
            d_base_type = d_prompt_style = None
            d_type = config_dataset.type
            if isinstance(d_type, str):
                d_type_split = d_type.split(":")
                d_base_type = d_type_split[0]
                d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
            # pick the split to train on: "train" first, then `train_on_split`
            if "train" in ds:
                ds = ds["train"]
            elif (
                isinstance(ds, DatasetDict)
                and config_dataset.train_on_split
                and config_dataset.train_on_split in ds
            ):
                ds = ds[config_dataset.train_on_split]
            elif isinstance(ds, DatasetDict):
                raise ValueError(
                    f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
                )

            dataset_wrapper, dataset_prompter = get_dataset_wrapper(
                config_dataset=config_dataset,
                dataset=ds,
                tokenizer=tokenizer,
                cfg=cfg,
                d_base_type=d_base_type,
                d_prompt_style=d_prompt_style,
            )
            datasets.append(dataset_wrapper)
            prompters.append(dataset_prompter)

        LOG.info("merging datasets")
        dataset = concatenate_datasets(datasets)

        # a single dataset keeps its original order
        if len(datasets) > 1:
            LOG.info("shuffle merged datasets")
            dataset = dataset.shuffle(seed=seed)

        dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)

        # only local rank 0 writes the prepared dataset (and pushes it to the hub)
        if cfg.local_rank == 0:
            LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
            dataset.save_to_disk(prepared_ds_path)
            if cfg.push_dataset_to_hub:
                LOG.info(
                    f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                )
                dataset.push_to_hub(
                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
                )

    return dataset, prompters


def get_ds_type(config_dataset: DictDefault):
    """
    Pick the loader type for a dataset config entry.

    An explicit ``ds_type`` on the config wins. Otherwise the path is scanned
    for known extension markers, checked in a fixed order, and the first one
    found decides the type. ``"json"`` is the fallback.
    """
    explicit_type = config_dataset.ds_type
    if explicit_type:
        return explicit_type

    # order matters: the first marker contained in the path is used
    marker_to_type = (
        (".parquet", "parquet"),
        (".arrow", "arrow"),
        (".csv", "csv"),
        (".txt", "text"),
    )
    for marker, loader_type in marker_to_type:
        if marker in config_dataset.path:
            return loader_type
    return "json"


def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
    """
    Load the tokenized dataset and split it into train and eval parts.

    The dataset is optionally reduced to a single shard
    (``cfg.dataset_shard_num`` / ``cfg.dataset_shard_idx``). When
    ``cfg.val_set_size`` is set it is then split without shuffling, using
    explicit fingerprints for both halves; otherwise the whole dataset is the
    train set and the eval set is ``None``.

    Returns:
        Tuple of ``(train_dataset, eval_dataset, prompters)``.
    """
    dataset, prompters = load_tokenized_prepared_datasets(
        tokenizer, cfg, default_dataset_prepared_path
    )

    wants_shard = cfg.dataset_shard_num and cfg.dataset_shard_idx is not None
    if wants_shard:
        LOG.info(
            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
        )
        dataset = dataset.shard(
            num_shards=cfg.dataset_shard_num,
            index=cfg.dataset_shard_idx,
        )

    if not cfg.val_set_size:
        return dataset, None, prompters

    def _split_fingerprint(split_name):
        # fingerprint source: "<dataset fingerprint>|<val_set_size>|<split>|<seed>"
        parts = [
            dataset._fingerprint,  # pylint: disable=protected-access
            str(cfg.val_set_size),
            split_name,
            str(cfg.seed or 42),
        ]
        return md5("|".join(parts))

    # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
    train_fingerprint = _split_fingerprint("train")
    test_fingerprint = _split_fingerprint("test")

    splits = dataset.train_test_split(
        test_size=cfg.val_set_size,
        shuffle=False,
        seed=cfg.seed or 42,
        train_new_fingerprint=train_fingerprint,
        test_new_fingerprint=test_fingerprint,
    )
    return splits["train"], splits["test"], prompters


def get_dataset_wrapper(
    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
    """
    Wrap ``dataset`` in the tokenizing dataset that matches its configured type.

    Resolution order:
      1. a dataset that already has ``input_ids``, ``attention_mask`` and
         ``labels`` columns is returned untouched;
      2. a ``DictDefault`` type is loaded as a ``"user_defined"`` strategy;
      3. a strategy resolved by ``load(config_dataset.type, ...)``;
      4. one of the built-in prompt types selected through ``d_base_type``.

    Returns:
        Tuple of ``(dataset_wrapper, dataset_prompter)``.

    Raises:
        ValueError: when no strategy matches ``config_dataset.type``.
    """
    ds_kwargs = {
        "process_count": cfg.dataset_processes,
        "keep_in_memory": cfg.dataset_keep_in_memory is True,
    }

    pretokenized_columns = ("input_ids", "attention_mask", "labels")
    if all(column in dataset.features for column in pretokenized_columns):
        # dataset is already tokenized, just drop it straight in
        return dataset, UnsupportedPrompter()

    if isinstance(config_dataset.type, DictDefault):
        ds_strategy = load(
            "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
        )
        dataset_prompter = UnsupportedPrompter()
        return TokenizedPromptDataset(ds_strategy, dataset, **ds_kwargs), dataset_prompter

    if ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
        dataset_prompter = UnsupportedPrompter()
        return TokenizedPromptDataset(ds_strategy, dataset, **ds_kwargs), dataset_prompter

    # built-in types: base type -> (prompter class, tokenizing strategy class)
    builtin_types = {
        "alpaca": (AlpacaPrompter, AlpacaPromptTokenizingStrategy),
        "explainchoice": (
            MultipleChoiceExplainPrompter,
            AlpacaMultipleChoicePromptTokenizingStrategy,
        ),
        "concisechoice": (
            MultipleChoiceConcisePrompter,
            AlpacaMultipleChoicePromptTokenizingStrategy,
        ),
        "summarizetldr": (SummarizeTLDRPrompter, SummarizeTLDRPromptTokenizingStrategy),
        "jeopardy": (JeopardyPrompter, JeopardyPromptTokenizingStrategy),
        "oasst": (AlpacaPrompter, OpenAssistantPromptTokenizingStrategy),
        "gpteacher": (GPTeacherPrompter, GPTeacherPromptTokenizingStrategy),
        "reflection": (ReflectAlpacaPrompter, AlpacaReflectionPTStrategy),
    }

    if d_base_type in builtin_types:
        prompter_cls, strategy_cls = builtin_types[d_base_type]
        dataset_prompter = prompter_cls(d_prompt_style)
        ds_strategy = strategy_cls(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        return TokenizedPromptDataset(ds_strategy, dataset, **ds_kwargs), dataset_prompter

    # nothing matched: report it, hinting at the ".load_" spelling when relevant
    suffix = ""
    if ":load_" in config_dataset.type:
        suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
    LOG.error(
        f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
    )
    raise ValueError(
        f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
    )


def encode_pretraining(
    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
    """
    Tokenize ``examples`` and pack them into rows of ``max_tokens`` tokens.

    Each example is truncated to ``max_tokens - 2`` tokens and gets the EOS and
    PAD token ids appended (attention mask ``1`` and ``0`` respectively).
    Examples are then concatenated in order into a running buffer; whenever the
    next example does not fit, the buffer is padded up to ``max_tokens`` and
    emitted as one row. A non-empty leftover buffer is padded and emitted too.

    Returns:
        Dict with ``input_ids``, ``labels`` (same values as ``input_ids``) and
        ``attention_mask``, each a list of token lists.
    """
    tokenized = tokenizer(
        examples,
        truncation=True,
        max_length=max_tokens - 2,
        add_special_tokens=True,
    )

    # Append EOS and PAD tokens to every sequence, and extend its mask to match
    sequences = [
        torch.cat(
            (
                torch.tensor(seq),
                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
            ),
            dim=0,
        )
        for seq in tokenized["input_ids"]
    ]
    masks = [
        torch.cat((torch.tensor(seq), torch.tensor([1, 0])), dim=0)
        for seq in tokenized["attention_mask"]
    ]

    def _empty():
        return torch.tensor([], dtype=torch.long)

    def _pad_to_max(buf, fill_value):
        # right-pad `buf` with `fill_value` up to max_tokens entries
        filler = torch.full(
            (max_tokens - buf.numel(),),
            fill_value,
            dtype=torch.long,
        )
        return torch.cat((buf, filler), dim=0)

    packed_ids = []
    packed_masks = []
    cur_ids, cur_mask = _empty(), _empty()

    # Concatenate tokens so that their lengths are less than max_tokens
    for ids, mask in zip(sequences, masks):
        row_is_full = cur_ids.numel() == max_tokens
        if not row_is_full and cur_ids.numel() + ids.numel() <= max_tokens:
            cur_ids = torch.cat((cur_ids, ids), dim=0)
            cur_mask = torch.cat((cur_mask, mask), dim=0)
            continue

        # the next sequence does not fit: close the current row, start a new one
        if not row_is_full:
            cur_ids = _pad_to_max(cur_ids, tokenizer.pad_token_id)
            cur_mask = _pad_to_max(cur_mask, 0)
        packed_ids.append(cur_ids)
        packed_masks.append(cur_mask)
        cur_ids = torch.cat((_empty(), ids), dim=0)
        cur_mask = torch.cat((_empty(), mask), dim=0)

    if cur_ids.numel() > 0:  # for any leftover tokens
        if cur_ids.numel() < max_tokens:  # make all sequences equal in size
            cur_ids = _pad_to_max(cur_ids, tokenizer.pad_token_id)
            cur_mask = _pad_to_max(cur_mask, 0)
        packed_ids.append(cur_ids)
        packed_masks.append(cur_mask)

    ret = {
        "input_ids": [row.tolist() for row in packed_ids],
        "labels": [row.tolist() for row in packed_ids],
        "attention_mask": [row.tolist() for row in packed_masks],
    }

    LOG.debug(len(ret["input_ids"]))
    return ret


def load_pretraining_dataset(
    path, tokenizer, cfg, name=None, max_tokens=2048, seed=42, buffer_size=10_000
):
    """
    Build a streaming, shuffled and tokenized pretraining dataset.

    The "train" split of `path` is opened in streaming mode, shuffled through a
    buffer, and mapped in batches over its "text" column. When
    `cfg.sample_packing` is truthy the rows are encoded with
    `encode_packed_pretraining`, otherwise with `encode_pretraining`.

    Side effect: when `cfg.sample_packing` is truthy, `cfg.micro_batch_size` is
    overwritten with 1 after its original value has been captured for the
    collator and the packed encoder.

    Args:
        path: dataset path/identifier forwarded to `load_dataset`.
        tokenizer: tokenizer handed to the encode function (and the collator
            when packing).
        cfg: config object; `sample_packing` and `micro_batch_size` are read,
            and `micro_batch_size` may be rewritten (see above).
        name: optional dataset configuration name forwarded to `load_dataset`.
        max_tokens: sequence length passed to the encode function.
        seed: seed for the streaming shuffle.
        buffer_size: number of rows used both as the shuffle buffer size and
            as the batch size of the batched `map`. Defaults to 10_000, the
            value that was previously hard-coded.

    Returns:
        The mapped streaming dataset returned by `dataset.map`.
    """
    if cfg.sample_packing:
        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
        )
        encode = functools.partial(
            encode_packed_pretraining,
            tokenizer,
            collate_fn,
            max_seq_length=max_tokens,
            batch_size=cfg.micro_batch_size,
        )
        # set this to 1 so downstream data_loader doesn't try to increase the batch again
        cfg.micro_batch_size = 1
    else:
        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

    dataset = load_dataset(path, streaming=True, split="train", name=name)
    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
    # NOTE(review): `dataset.features` is assumed to be populated here; a
    # streaming dataset without declared features would fail on `.keys()` —
    # confirm against the datasets in use.
    dataset = dataset.map(
        encode,
        batched=True,
        batch_size=buffer_size,
        input_columns="text",
        # remove all the existing columns after mapping since they end up having
        # a different length than the encoded/tokenized column
        remove_columns=list(dataset.features.keys()),
    )
    return dataset


def encode_packed_pretraining(
    tokenizer: PreTrainedTokenizerBase,
    collate_fn,
    examples: List[str],
    max_seq_length: int = 2048,
    batch_size: int = 4,
) -> Dict[str, List]:
    """
    Tokenize raw text rows and regroup them into collated, packed batches.

    Each example is tokenized with truncation at `max_seq_length - 1` tokens;
    overflowing tokens are returned as extra rows with a stride of 256. An EOS
    token id (with attention mask 1) is appended to every row. The rows are
    then wrapped in a `Dataset`, passed through
    `process_pretraining_datasets_for_packing`, batched by a
    `MultipackBatchSampler` and collated with `collate_fn`.

    Args:
        tokenizer: tokenizer used to encode `examples`; its `eos_token_id` is
            appended to every row.
        collate_fn: callable applied to each sampled batch of features.
        examples: raw text rows to encode.
        max_seq_length: sequence length budget per packed sequence.
        batch_size: number of sequences per sampled batch; also scales the
            sampler's `batch_max_len`.

    Returns:
        A mapping from feature name to the list of collated entries (one per
        sampled batch, with dimension 0 squeezed). The "length" feature is
        left out.
    """
    # pylint: disable=duplicate-code
    # long rows are split into several rows that overlap by `stride` tokens
    encoded = tokenizer(
        examples,
        truncation=True,
        max_length=max_seq_length - 1,
        add_special_tokens=True,
        return_overflowing_tokens=True,
        stride=256,
    )

    # max_length above leaves one slot free for the EOS token appended here
    train_dataset = Dataset.from_dict(
        {
            "input_ids": [
                ids + [tokenizer.eos_token_id] for ids in encoded["input_ids"]
            ],
            "attention_mask": [mask + [1] for mask in encoded["attention_mask"]],
        }
    )
    train_dataset = process_pretraining_datasets_for_packing(
        train_dataset, max_seq_length
    )

    batch_sampler = MultipackBatchSampler(
        RandomSampler(train_dataset),
        batch_size=batch_size,
        drop_last=True,
        batch_max_len=batch_size * max_seq_length,
        lengths=get_dataset_lengths(train_dataset),
    )

    packed = defaultdict(list)
    for batch_indices in batch_sampler:
        rows = train_dataset[batch_indices]
        rows["labels"] = rows["input_ids"].copy()
        collated = collate_fn(rows)

        for column in rows.keys():
            if column != "length":
                packed[column].append(collated[column].squeeze(0))

    return packed