Jan Philipp Harries
Fix pretraining with iterable/streaming Dataset (#556)
* return without packing prep/len
* fix remove columns
* fix encode arguments
* add error when max steps not set
* fix test
---------
Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
- src/axolotl/utils/config.py +4 -0
- src/axolotl/utils/data.py +14 -5
- tests/test_data.py +1 -1
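Background on the "add error when max steps not set" bullet: a streamed pretraining dataset is an IterableDataset, which exposes no length, so the Hugging Face Trainer cannot derive the number of training steps or the learning-rate schedule on its own. A minimal stand-alone sketch of that limitation using the datasets library directly (illustrative only, not code from this commit):

# A streaming-style dataset has no __len__, which is why max_steps must be
# supplied explicitly when pretraining on an iterable dataset.
from datasets import Dataset

streamed = Dataset.from_dict({"text": ["a", "b", "c"]}).to_iterable_dataset()
try:
    len(streamed)
except TypeError as err:
    print(f"no length available: {err}")  # the Trainer hits the same wall without max_steps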
src/axolotl/utils/config.py
CHANGED
@@ -191,6 +191,10 @@ def validate_config(cfg):
         LOG.warning(
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )
+    if cfg.pretraining_dataset and not cfg.max_steps:
+        raise ValueError(
+            "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
+        )
 
     if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
         not cfg.optimizer or "adamw" not in cfg.optimizer
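The branch added to validate_config fails fast when a streamed pretraining dataset is configured without max_steps, instead of letting training error out later once the Trainer tries to infer the dataset length. A rough stand-alone sketch of that behavior, with SimpleNamespace standing in for axolotl's cfg object and check_pretraining_cfg as a hypothetical helper (not the project's API):

# Hypothetical stand-in that mirrors the branch added to validate_config() above.
from types import SimpleNamespace

def check_pretraining_cfg(cfg):
    # Streamed pretraining needs an explicit step budget: the Trainer cannot
    # infer length or build a learning-rate schedule from an iterable dataset.
    if cfg.pretraining_dataset and not cfg.max_steps:
        raise ValueError("max_steps must be set when using an iterable pretraining_dataset")

check_pretraining_cfg(SimpleNamespace(pretraining_dataset="some/streaming-corpus", max_steps=10_000))  # passes
try:
    check_pretraining_cfg(SimpleNamespace(pretraining_dataset="some/streaming-corpus", max_steps=None))
except ValueError as err:
    print(err)  # surfaces at config-validation time rather than mid-setup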
src/axolotl/utils/data.py
CHANGED
@@ -3,7 +3,7 @@ import functools
 import hashlib
 import logging
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 from datasets import (
@@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer):
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
+        return train_dataset, eval_dataset, cfg.max_steps
 
     with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
@@ -527,9 +528,11 @@ def load_prepare_datasets(
     return train_dataset, eval_dataset
 
 
-def encode_pretraining(tokenizer, max_tokens, examples):
+def encode_pretraining(
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+) -> Dict[str, List]:
     res = tokenizer(
-        examples["text"],
+        examples,
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,
@@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
     dataset = load_dataset(path, streaming=True, split="train")
     dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
-    # TODO dynamically figure out which columns/features to remove
-    dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
+    dataset = dataset.map(
+        encode,
+        batched=True,
+        input_columns="text",
+        remove_columns=[
+            "text",
+        ],
+    )
     return dataset
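The rewritten map call hands only the "text" column to the encode function (input_columns="text") and drops just that column afterwards, which is why encode_pretraining now receives a plain list of strings instead of the whole batch dict, and why the call no longer has to know about any other columns the streamed corpus happens to carry. A small self-contained sketch of that datasets behavior, with fake_encode standing in for the tokenizer-backed encode_pretraining (hypothetical, not from the commit):

# Self-contained illustration of input_columns / remove_columns on datasets.map.
from datasets import Dataset

def fake_encode(texts):
    # With input_columns="text" and batched=True, the function receives the list
    # of strings for the batch, not the full {"text": [...], "meta": [...]} dict.
    return {"input_ids": [[len(t)] for t in texts]}

ds = Dataset.from_dict({"text": ["hello", "hello, hello"], "meta": ["a", "b"]})
ds = ds.map(
    fake_encode,
    batched=True,
    input_columns="text",
    remove_columns=["text"],  # only drop a column that is guaranteed to exist
)
print(ds.column_names)  # "text" is gone; "meta" and the new "input_ids" remain
print(ds[0])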
tests/test_data.py
CHANGED
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
                 "hello, hello",
             ]
         }
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
 
         self.assertEqual(len(result["input_ids"]), 3)
 