Merge pull request #293 from NanoCode012/fix/tokenize-speed
src/axolotl/datasets.py  (+11 -12)
@@ -1,12 +1,13 @@
 """Module containing Dataset functionality"""
 
 import logging
+import os
 from typing import List
 
 import torch
 from datasets import IterableDataset
 
-from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
+from .prompt_tokenizers import PromptTokenizingStrategy
 
 # We want this to be a wrapper for an existing dataset that we have loaded
 # lets use the concept of middlewares to wrap each dataset, for example
@@ -34,17 +35,15 @@ class TokenizedPromptDataset(IterableDataset):
         self.dataset = dataset
 
     def __iter__(self):
-        iterator = iter(self.dataset)
-        count = 0
-        # Loop through the entire dataset
-        for example in iterator:
-            try:
-                yield self.prompt_tokenizer.tokenize_prompt(example)
-                count += 1
-            except InvalidDataException:
-                pass
-        if count == 0:
-            raise RuntimeError("Expected at least one datapoint in dataset.")
+        features = self.dataset.features.keys()
+        num_proc = os.cpu_count()
+        return iter(
+            self.dataset.map(
+                self.prompt_tokenizer.tokenize_prompt,
+                num_proc=num_proc,
+                remove_columns=features,
+            )
+        )
 
 
 # TODO this isn't the best since it can't interleave datasets
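For context, here is a minimal standalone sketch (not part of the PR) of the technique the new `__iter__` switches to: `datasets.Dataset.map` with `num_proc` fans the tokenizing function out across CPU worker processes, instead of yielding examples one at a time from a single-process generator. The `tokenize_prompt` function below is a hypothetical stand-in for `PromptTokenizingStrategy.tokenize_prompt`; a real strategy would run a tokenizer and return fields like `input_ids`, `attention_mask`, and `labels`.

```python
# Minimal sketch, assuming a map-style datasets.Dataset.
# `tokenize_prompt` is a hypothetical stand-in for
# PromptTokenizingStrategy.tokenize_prompt.
import os

from datasets import Dataset


def tokenize_prompt(example):
    # A real strategy would tokenize example["text"] here.
    return {"input_ids": list(range(len(example["text"])))}


if __name__ == "__main__":  # guard needed where workers are spawned
    data = Dataset.from_dict({"text": ["hello", "world", "axolotl"]})
    tokenized = data.map(
        tokenize_prompt,
        # cap workers at the dataset size so tiny datasets don't
        # request more processes than there are examples
        num_proc=min(os.cpu_count() or 1, len(data)),
        remove_columns=list(data.features.keys()),  # drop the raw columns
    )
    print(tokenized[0])  # -> {'input_ids': [0, 1, 2, 3, 4]}
```

One design note: `num_proc` applies to map-style `Dataset` objects, whose `map` runs eagerly; `IterableDataset.map` is lazy and single-process. The new code reads `self.dataset.features`, so the wrapped dataset is presumably map-style despite the `IterableDataset` type hint, which also fits the TODO about not being able to interleave datasets.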