winglian committed
Commit a4f1241
1 Parent(s): 48f4c05

update readme and add typehints

Files changed (2):
  1. README.md +1 -7
  2. src/axolotl/utils/data.py +8 -7
README.md CHANGED

@@ -363,13 +363,7 @@ Pass the appropriate flag to the train command:
 
 ### Merge LORA to base
 
-Add below flag to train command above (and using LoRA)
-
-```bash
---merge_lora --lora_model_dir="./completed-model"
-```
-
-Add below flag to train command above (and using QLoRA)
+Add below flag to train command above
 
 ```bash
 --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
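
For reference, merging folds the low-rank adapter weights back into the base model so the result loads as a plain checkpoint. Below is a minimal sketch of the same operation done by hand with `peft`, assuming a peft-format adapter in `./completed-model`; the base model name and output path are placeholders, not axolotl's internals:

```python
# Hand-rolled LoRA merge with peft; a sketch, not axolotl's implementation.
# "huggyllama/llama-7b" and "./merged-model" are placeholder names.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",      # placeholder base model
    torch_dtype=torch.float16,  # unquantized weights, matching
)                               # --load_in_8bit=False --load_in_4bit=False

model = PeftModel.from_pretrained(base, "./completed-model")
merged = model.merge_and_unload()  # fold LoRA deltas into the base weights
merged.save_pretrained("./merged-model")
```

After merging, the output directory loads with `AutoModelForCausalLM.from_pretrained` alone, with no peft dependency at inference time.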
src/axolotl/utils/data.py CHANGED

@@ -1,6 +1,7 @@
 import logging
 from hashlib import md5
 from pathlib import Path
+from typing import Union
 
 from datasets import (
     load_from_disk,
@@ -80,7 +81,7 @@ def load_tokenized_prepared_datasets(
     logging.info("Loading raw datasets...")
     datasets = []
     for d in cfg.datasets:
-        ds = None
+        ds: Union[Dataset, DatasetDict] = None
        ds_from_hub = False
         try:
             load_dataset(d.path, streaming=True, use_auth_token=True)
 
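The `Union[Dataset, DatasetDict]` hint reflects how `datasets.load_dataset` behaves: with no `split` argument it returns a `DatasetDict` keyed by split name, while an explicit split yields a single `Dataset`. A small illustration; the dataset name here is arbitrary:

```python
# Why ds carries Union[Dataset, DatasetDict]: load_dataset's return type
# depends on whether a split is requested. "imdb" is just an example.
from datasets import Dataset, DatasetDict, load_dataset

all_splits = load_dataset("imdb")                # no split -> DatasetDict
one_split = load_dataset("imdb", split="train")  # explicit split -> Dataset

assert isinstance(all_splits, DatasetDict)
assert isinstance(one_split, Dataset)
```
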
@@ -90,32 +91,32 @@ def load_tokenized_prepared_datasets(
 
         # prefer local dataset, even if hub exists
         if Path(d.path).exists():
-            ds: IterableDataset = load_dataset(
+            ds: Dataset = load_dataset(
                 "json", data_files=d.path, streaming=False, split=None
             )
         elif ds_from_hub:
             if d.data_files:
-                ds = load_dataset(
+                ds: Dataset = load_dataset(
                     d.path,
                     streaming=False,
                     data_files=d.data_files,
                     use_auth_token=True,
                 )
             else:
-                ds = load_dataset(d.path, streaming=False, use_auth_token=True)
+                ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=True)
         else:
             fp = hf_hub_download(
                 repo_id=d.path, repo_type="dataset", filename=d.data_files
             )
-            ds = load_dataset("json", data_files=fp, streaming=False, split=None)
+            ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
             raise Exception("unhandled dataset load")
         # support for using a subset of the data
         if d.shards:
             if "train" in ds:
-                ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
+                ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
             else:
-                ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
+                ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
         d_type = d.type
         d_type_split = d_type.split(":")
         d_base_type = d_type_split[0]
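
Beyond the annotations, this hunk corrects `num_shards=cfg.shards` to `num_shards=d.shards`, so the shard count is read from the same per-dataset entry that the enclosing `if d.shards:` checks. A toy sketch of what the shuffle-then-shard subset keeps; the data is made up for illustration:

```python
# Shuffle deterministically, then keep the first of num_shards equal
# slices: 100 toy rows with num_shards=4 leaves 25 rows.
from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})
subset = ds.shuffle(seed=42).shard(num_shards=4, index=0)
print(len(subset))  # 25
```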