Jan Philipp Harries
		
		Jan Philipp Harries
		
	committed on
		
		
					Fix pretraining with iterable/streaming Dataset (#556)
Browse files
* return without packing prep/len
* fix remove columns
* fix encode arguments
* add error when max steps not set
* fix test
---------
Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
- src/axolotl/utils/config.py +4 -0
- src/axolotl/utils/data.py +14 -5
- tests/test_data.py +1 -1
    	
        src/axolotl/utils/config.py
    CHANGED
    
    | @@ -191,6 +191,10 @@ def validate_config(cfg): | |
| 191 | 
             
                    LOG.warning(
         | 
| 192 | 
             
                        "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         | 
| 193 | 
             
                    )
         | 
|  | |
|  | |
|  | |
|  | |
| 194 |  | 
| 195 | 
             
                if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
         | 
| 196 | 
             
                    not cfg.optimizer or "adamw" not in cfg.optimizer
         | 
|  | |
| 191 | 
             
                    LOG.warning(
         | 
| 192 | 
             
                        "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         | 
| 193 | 
             
                    )
         | 
| 194 | 
            +
                if cfg.pretraining_dataset and not cfg.max_steps:
         | 
| 195 | 
            +
                    raise ValueError(
         | 
| 196 | 
            +
                        "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
         | 
| 197 | 
            +
                    )
         | 
| 198 |  | 
| 199 | 
             
                if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
         | 
| 200 | 
             
                    not cfg.optimizer or "adamw" not in cfg.optimizer
         | 
    	
        src/axolotl/utils/data.py
    CHANGED
    
    | @@ -3,7 +3,7 @@ import functools | |
| 3 | 
             
            import hashlib
         | 
| 4 | 
             
            import logging
         | 
| 5 | 
             
            from pathlib import Path
         | 
| 6 | 
            -
            from typing import Tuple, Union
         | 
| 7 |  | 
| 8 | 
             
            import torch
         | 
| 9 | 
             
            from datasets import (
         | 
| @@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer): | |
| 74 | 
             
                    # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         | 
| 75 | 
             
                    train_dataset = train_dataset.with_format("torch")
         | 
| 76 | 
             
                    eval_dataset = None
         | 
|  | |
| 77 |  | 
| 78 | 
             
                with zero_first(is_main_process()):
         | 
| 79 | 
             
                    train_dataset, eval_dataset = process_datasets_for_packing(
         | 
| @@ -527,9 +528,11 @@ def load_prepare_datasets( | |
| 527 | 
             
                return train_dataset, eval_dataset
         | 
| 528 |  | 
| 529 |  | 
| 530 | 
            -
            def encode_pretraining( | 
|  | |
|  | |
| 531 | 
             
                res = tokenizer(
         | 
| 532 | 
            -
                    examples | 
| 533 | 
             
                    truncation=True,
         | 
| 534 | 
             
                    max_length=max_tokens - 2,
         | 
| 535 | 
             
                    add_special_tokens=True,
         | 
| @@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): | |
| 637 | 
             
                encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
         | 
| 638 | 
             
                dataset = load_dataset(path, streaming=True, split="train")
         | 
| 639 | 
             
                dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
         | 
| 640 | 
            -
                 | 
| 641 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 642 | 
             
                return dataset
         | 
|  | |
| 3 | 
             
            import hashlib
         | 
| 4 | 
             
            import logging
         | 
| 5 | 
             
            from pathlib import Path
         | 
| 6 | 
            +
            from typing import Dict, List, Tuple, Union
         | 
| 7 |  | 
| 8 | 
             
            import torch
         | 
| 9 | 
             
            from datasets import (
         | 
|  | |
| 74 | 
             
                    # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         | 
| 75 | 
             
                    train_dataset = train_dataset.with_format("torch")
         | 
| 76 | 
             
                    eval_dataset = None
         | 
| 77 | 
            +
                    return train_dataset, eval_dataset, cfg.max_steps
         | 
| 78 |  | 
| 79 | 
             
                with zero_first(is_main_process()):
         | 
| 80 | 
             
                    train_dataset, eval_dataset = process_datasets_for_packing(
         | 
|  | |
| 528 | 
             
                return train_dataset, eval_dataset
         | 
| 529 |  | 
| 530 |  | 
| 531 | 
            +
            def encode_pretraining(
         | 
| 532 | 
            +
                tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
         | 
| 533 | 
            +
            ) -> Dict[str, List]:
         | 
| 534 | 
             
                res = tokenizer(
         | 
| 535 | 
            +
                    examples,
         | 
| 536 | 
             
                    truncation=True,
         | 
| 537 | 
             
                    max_length=max_tokens - 2,
         | 
| 538 | 
             
                    add_special_tokens=True,
         | 
|  | |
| 640 | 
             
                encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
         | 
| 641 | 
             
                dataset = load_dataset(path, streaming=True, split="train")
         | 
| 642 | 
             
                dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
         | 
| 643 | 
            +
                dataset = dataset.map(
         | 
| 644 | 
            +
                    encode,
         | 
| 645 | 
            +
                    batched=True,
         | 
| 646 | 
            +
                    input_columns="text",
         | 
| 647 | 
            +
                    remove_columns=[
         | 
| 648 | 
            +
                        "text",
         | 
| 649 | 
            +
                    ],
         | 
| 650 | 
            +
                )
         | 
| 651 | 
             
                return dataset
         | 
    	
        tests/test_data.py
    CHANGED
    
    | @@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase): | |
| 35 | 
             
                            "hello, hello",
         | 
| 36 | 
             
                        ]
         | 
| 37 | 
             
                    }
         | 
| 38 | 
            -
                    result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
         | 
| 39 |  | 
| 40 | 
             
                    self.assertEqual(len(result["input_ids"]), 3)
         | 
| 41 |  | 
|  | |
| 35 | 
             
                            "hello, hello",
         | 
| 36 | 
             
                        ]
         | 
| 37 | 
             
                    }
         | 
| 38 | 
            +
                    result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
         | 
| 39 |  | 
| 40 | 
             
                    self.assertEqual(len(result["input_ids"]), 3)
         | 
| 41 |  |