winglian committed
Commit
8d43785
1 Parent(s): 8e2a560

fix sharegpt handling from hf, don't worry about loading llama if using earlier transformers release

configs/llama_65B_alpaca.yml CHANGED
@@ -5,7 +5,8 @@ load_in_8bit: true
 datasets:
   - path: data/alpaca_data_gpt4.jsonl
     type: alpaca
-  - path: data/vicuna_cleaned.jsonl
+  - path: anon8231489123/ShareGPT_Vicuna_unfiltered
+    data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
     type: sharegpt
   - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
     type: gpteacher
@@ -30,6 +31,8 @@ wandb_log_model: checkpoint
 output_dir: ./lora-llama-alpaca
 batch_size: 128
 micro_batch_size: 16
+warmup_steps: 1000
+save_steps:
 num_epochs: 5
 learning_rate: 0.00003
 train_on_inputs: false
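
The replaced `path` now names a dataset repo on the Hugging Face Hub, with `data_files` selecting one JSON file inside it. A minimal sketch of the load_dataset call these two keys map to (mirroring the loader change in src/axolotl/utils/data.py below):

from datasets import load_dataset

# Streaming load of one JSON file from a Hub dataset repo,
# using the path/data_files pair from the config above.
ds = load_dataset(
    "anon8231489123/ShareGPT_Vicuna_unfiltered",
    data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
    streaming=True,
)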
src/axolotl/prompters.py CHANGED
@@ -128,6 +128,10 @@ conv_vicuna_v1_1 = Conversation(
 
 class ShareGPTPrompter:
     def build_prompt(self, source, tokenizer):
+        # ignore the system prompt if provided
+        if source[0]["from"] == "system":
+            source.pop(0)
+
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
             # also happens on the data splitting leaving empty conversations
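
The new guard pops a leading system turn before the length check runs, so system-prefixed conversations are normalized first. A standalone sketch of the same logic on a made-up ShareGPT-style conversation:

# Toy ShareGPT-style conversation with a leading system turn (illustrative data).
source = [
    {"from": "system", "value": "You are a helpful assistant."},
    {"from": "human", "value": "Hi!"},
    {"from": "gpt", "value": "Hello! How can I help?"},
]

# Same guard as build_prompt: ignore the system prompt if provided.
if source[0]["from"] == "system":
    source.pop(0)

assert len(source) >= 2  # still a usable back-and-forth conversation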
src/axolotl/utils/data.py CHANGED
@@ -3,6 +3,7 @@ from hashlib import md5
 from pathlib import Path
 
 from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
+from huggingface_hub import hf_hub_download
 
 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 from axolotl.prompt_tokenizers import (
@@ -50,6 +51,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     logging.info("Loading raw datasets...")
     datasets = []
     for d in cfg.datasets:
+        ds = None
         ds_from_hub = False
         try:
             load_dataset(d.path, streaming=True)
@@ -63,9 +65,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 "json", data_files=d.path, streaming=True, split=None
             )
         elif ds_from_hub:
-            ds = load_dataset(d.path, streaming=True)
+            if d.data_files:
+                ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
+            else:
+                ds = load_dataset(d.path, streaming=True)
         else:
-            raise Exception(f"unhandled dataset load for {d.path}")
+            fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
+            ds = load_dataset("json", data_files=fp, streaming=True, split=None)
+        if not ds:
+            raise Exception("unhandled dataset load")
 
         if d.type == "alpaca":
             ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -111,6 +119,8 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
         seq_length=max_packed_sequence_len,
     )
     logging.info("merging, packing, shuffling, and splitting master dataset")
+    # TODO don't split dataset here, shuffle and save first, then split, that way we can
+    # re-split when loading again
    dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
         test_size=cfg.val_set_size, shuffle=True, seed=42
     )
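
With these changes the loader has three paths: a local JSON/JSONL file, a Hub dataset repo (optionally narrowed via data_files), and a fallback that downloads a raw file from a Hub repo with hf_hub_download and parses it as JSON. A condensed sketch of the resulting control flow; the local-file check and the exception type caught by the Hub probe are assumptions, since the diff does not show them:

from pathlib import Path

from datasets import load_dataset
from huggingface_hub import hf_hub_download

def load_one(d):
    # Probe whether d.path resolves to a dataset repo on the Hub.
    ds_from_hub = False
    try:
        load_dataset(d.path, streaming=True)
        ds_from_hub = True
    except FileNotFoundError:  # assumed; the actual except clause is outside this diff
        pass

    ds = None
    if Path(d.path).exists():  # assumed local-file check
        # Local JSON/JSONL file.
        ds = load_dataset("json", data_files=d.path, streaming=True, split=None)
    elif ds_from_hub:
        # Proper dataset repo; narrow to specific files when configured.
        if d.data_files:
            ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
        else:
            ds = load_dataset(d.path, streaming=True)
    else:
        # Raw file in a Hub repo: download it, then load it as JSON.
        fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
        ds = load_dataset("json", data_files=fp, streaming=True, split=None)
    if not ds:
        raise Exception("unhandled dataset load")
    return ds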
src/axolotl/utils/models.py CHANGED
@@ -7,11 +7,16 @@ import torch
 import transformers
 from transformers import (
     AutoModelForCausalLM,
-    LlamaForCausalLM,
-    LlamaTokenizer,
     AutoTokenizer,
     PreTrainedModel,
 )
+try:
+    from transformers import (
+        LlamaForCausalLM,
+        LlamaTokenizer,
+    )
+except ImportError:
+    logging.warning("This version of transformers does not support Llama. Consider upgrading.")
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
@@ -95,7 +100,7 @@ def load_model(
             else True,
         )
         load_in_8bit = False
-    elif is_llama_derived_model:
+    elif is_llama_derived_model and "LlamaForCausalLM" in globals():
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
@@ -130,7 +135,7 @@ def load_model(
 
     if not tokenizer:
        try:
-            if is_llama_derived_model:
+            if is_llama_derived_model and "LlamaTokenizer" in globals():
                 tokenizer = LlamaTokenizer.from_pretrained(model)
             else:
                 tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
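
The pattern in models.py is an optional import: the Llama classes only exist in newer transformers releases, so the import is wrapped in try/except and call sites check whether the names were actually bound before using them. A minimal self-contained sketch of that pattern (can_load_llama is a hypothetical helper for illustration):

import logging

# Optional import: LlamaForCausalLM/LlamaTokenizer only exist in newer
# transformers releases; older releases skip the binding with a warning.
try:
    from transformers import LlamaForCausalLM, LlamaTokenizer
except ImportError:
    logging.warning("This version of transformers does not support Llama. Consider upgrading.")

def can_load_llama() -> bool:
    # Mirrors the guard used at the call sites above.
    return "LlamaForCausalLM" in globals()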