fozziethebeat committed
Commit e634118
Parent: 02af082

Support loading datasets saved via save_to_disk (#1432)


* Support loading datasets saved via save_to_disk

* Add comprehensive unit tests

* Fix dataset tests due to new hash changes

Files changed (2)
  1. src/axolotl/utils/data.py +13 -9
  2. tests/test_datasets.py +272 -0
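
Previously the loader always routed local directories through `load_dataset`, which rejects `save_to_disk` output unless explicit `data_files` are listed. The sketch below shows the flow this commit enables, mirroring the added `test_load_from_save_to_disk` test; the output path, sample record, and YAML snippet are illustrative, not part of the commit.

# Sketch of the newly supported flow (assumed local path "data/alpaca_subset").
from datasets import Dataset

ds = Dataset.from_list(
    [
        {
            "instruction": "Evaluate this sentence for spelling and grammar mistakes",
            "input": "He finnished his meal and left the resturant",
            "output": "He finished his meal and left the restaurant.",
        }
    ]
)
# Persist in the datasets Arrow format (writes the Arrow shards plus state.json).
ds.save_to_disk("data/alpaca_subset")

# The directory can now be referenced directly from an axolotl dataset config,
# with no ds_type or data_files needed:
#
# datasets:
#   - path: data/alpaca_subset
#     type: alpaca
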
src/axolotl/utils/data.py CHANGED
@@ -1,4 +1,5 @@
 """Module containing data utilities"""
+
 import functools
 import hashlib
 import logging
@@ -223,7 +224,7 @@ def load_tokenized_prepared_datasets(
                     token=use_auth_token,
                 )
                 ds_from_hub = True
-            except (FileNotFoundError, ConnectionError, HFValidationError):
+            except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
                 pass
 
             ds_from_cloud = False
@@ -290,14 +291,17 @@ def load_tokenized_prepared_datasets(
             local_path = Path(config_dataset.path)
             if local_path.exists():
                 if local_path.is_dir():
-                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
-                    ds = load_dataset(
-                        config_dataset.path,
-                        name=config_dataset.name,
-                        data_files=config_dataset.data_files,
-                        streaming=False,
-                        split=None,
-                    )
+                    if config_dataset.data_files:
+                        ds_type = get_ds_type(config_dataset)
+                        ds = load_dataset(
+                            ds_type,
+                            name=config_dataset.name,
+                            data_files=config_dataset.data_files,
+                            streaming=False,
+                            split=None,
+                        )
+                    else:
+                        ds = load_from_disk(config_dataset.path)
                 elif local_path.is_file():
                     ds_type = get_ds_type(config_dataset)
 
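
Net effect of the data.py change: a local directory that lists data_files still goes through `load_dataset` with the resolved ds_type, while a bare directory is now treated as `save_to_disk` output and handed to `load_from_disk`. The ValueError added to the hub-probe except clause is likely what lets such a path fall through to the local branch, since recent versions of the datasets library raise a ValueError telling you to use `load_from_disk` when `load_dataset` is pointed at a saved dataset. A small standalone illustration of that library behavior (not part of the commit; paths are temporary):

import tempfile

from datasets import Dataset, load_dataset, load_from_disk

with tempfile.TemporaryDirectory() as tmp_dir:
    Dataset.from_dict({"text": ["hello"]}).save_to_disk(tmp_dir)

    loaded = load_from_disk(tmp_dir)  # works: returns the saved Dataset
    try:
        load_dataset(tmp_dir)  # rejected: save_to_disk output is not a loadable repo/script
    except ValueError as err:
        print(err)  # message suggests using load_from_disk instead
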
tests/test_datasets.py ADDED
@@ -0,0 +1,272 @@
"""
Test dataset loading under various conditions.
"""

import shutil
import tempfile
import unittest
from pathlib import Path

from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault


class TestDatasetPreparation(unittest.TestCase):
    """Test a configured dataloader."""

    def setUp(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
            }
        )
        # Alpaca dataset.
        self.dataset = Dataset.from_list(
            [
                {
                    "instruction": "Evaluate this sentence for spelling and grammar mistakes",
                    "input": "He finnished his meal and left the resturant",
                    "output": "He finished his meal and left the restaurant.",
                }
            ]
        )

    def test_load_hub(self):
        """Core use case. Verify that processing data from the hub works"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 2000
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_local_hub(self):
        """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
            )

            prepared_path = Path(tmp_dir) / "prepared"
            # Right now a local copy that doesn't fully conform to a dataset
            # must list data_files and ds_type otherwise the loader won't know
            # how to load it.
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": "mhenrichsen/alpaca_2k_test",
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
                            ],
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 2000
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features
            shutil.rmtree(tmp_ds_path)

    def test_load_from_save_to_disk(self):
        """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
            self.dataset.save_to_disk(tmp_ds_name)

            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_name),
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_dir_of_parquet(self):
        """Usual use case. Verify a directory of parquet files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.parquet"
            self.dataset.to_parquet(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "parquet",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_dir_of_json(self):
        """Standard use case. Verify a directory of json files can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
            tmp_ds_dir.mkdir()
            tmp_ds_path = tmp_ds_dir / "shard1.json"
            self.dataset.to_json(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_dir),
                            "ds_type": "json",
                            "name": "test_data",
                            "data_files": [
                                str(tmp_ds_path),
                            ],
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_single_parquet(self):
        """Standard use case. Verify a single parquet file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
            self.dataset.to_parquet(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features

    def test_load_from_single_json(self):
        """Standard use case. Verify a single json file can be loaded."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
            self.dataset.to_json(tmp_ds_path)

            prepared_path: Path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 256,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "name": "test_data",
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 1
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features


if __name__ == "__main__":
    unittest.main()
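
The tests should run either through the __main__ guard (python tests/test_datasets.py) or under pytest (python -m pytest tests/test_datasets.py). Note that they fetch the huggyllama/llama-7b tokenizer and the mhenrichsen/alpaca_2k_test dataset from the Hub, so network access is required.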