fix sharegpt handling from hf, don't worry about loading llama if using earlier transformers release
- configs/llama_65B_alpaca.yml +4 -1
- src/axolotl/prompters.py +4 -0
- src/axolotl/utils/data.py +12 -2
- src/axolotl/utils/models.py +9 -4
configs/llama_65B_alpaca.yml
@@ -5,7 +5,8 @@ load_in_8bit: true
 datasets:
   - path: data/alpaca_data_gpt4.jsonl
     type: alpaca
-  - path:
+  - path: anon8231489123/ShareGPT_Vicuna_unfiltered
+    data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
     type: sharegpt
   - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
     type: gpteacher
@@ -30,6 +31,8 @@ wandb_log_model: checkpoint
 output_dir: ./lora-llama-alpaca
 batch_size: 128
 micro_batch_size: 16
+warmup_steps: 1000
+save_steps:
 num_epochs: 5
 learning_rate: 0.00003
 train_on_inputs: false
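The new sharegpt entry points at a Hugging Face Hub dataset repo instead of a local JSONL, with data_files naming the specific file inside that repo. A minimal sketch of how such an entry resolves through datasets.load_dataset (the repo and file names are the ones from the config above; the rest is illustrative, not the repo's loader code):

from datasets import load_dataset

# Hub repo + data_files taken from the config; streaming mirrors how the
# loader probes datasets, but this snippet is only an illustration.
ds = load_dataset(
    "anon8231489123/ShareGPT_Vicuna_unfiltered",
    data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
    streaming=True,
)
first = next(iter(ds["train"]))  # streaming load returns an iterable dataset dict
print(list(first.keys()))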
src/axolotl/prompters.py
@@ -128,6 +128,10 @@ conv_vicuna_v1_1 = Conversation(

 class ShareGPTPrompter:
     def build_prompt(self, source, tokenizer):
+        # ignore the system prompt if provided
+        if source[0]["from"] == "system":
+            source.pop(0)
+
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
             # also happens on the data splitting leaving empty conversations
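The added guard drops a leading system turn, so a ShareGPT conversation that opens with a system prompt is judged on its actual human/gpt exchange. A self-contained sketch of the same check (the sample conversation is made up; only the guard mirrors the diff):

# Made-up ShareGPT-style conversation: a list of {"from": ..., "value": ...} turns.
source = [
    {"from": "system", "value": "You are a helpful assistant."},
    {"from": "human", "value": "What does warmup_steps do?"},
    {"from": "gpt", "value": "It ramps the learning rate up over the first N optimizer steps."},
]

# ignore the system prompt if provided (same guard as in the diff above)
if source and source[0]["from"] == "system":
    source.pop(0)

if len(source) < 2:
    print("skipping: not a real back-and-forth conversation")
else:
    print(f"{len(source)} turns kept")  # -> 2 turns kept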
src/axolotl/utils/data.py
@@ -3,6 +3,7 @@ from hashlib import md5
 from pathlib import Path

 from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
+from huggingface_hub import hf_hub_download

 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 from axolotl.prompt_tokenizers import (
@@ -50,6 +51,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     logging.info("Loading raw datasets...")
     datasets = []
     for d in cfg.datasets:
+        ds = None
         ds_from_hub = False
         try:
             load_dataset(d.path, streaming=True)
@@ -63,9 +65,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 "json", data_files=d.path, streaming=True, split=None
             )
         elif ds_from_hub:
-            ds = load_dataset(d.path, streaming=True)
+            if d.data_files:
+                ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
+            else:
+                ds = load_dataset(d.path, streaming=True)
         else:
-            raise Exception("unhandled dataset load")
+            fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
+            ds = load_dataset("json", data_files=fp, streaming=True, split=None)
+        if not ds:
+            raise Exception("unhandled dataset load")

         if d.type == "alpaca":
             ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -111,6 +119,8 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             seq_length=max_packed_sequence_len,
         )
         logging.info("merging, packing, shuffling, and splitting master dataset")
+        # TODO don't split dataset here, shuffle and save first, then split, that way we can
+        # re-split when loading again
         dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
             test_size=cfg.val_set_size, shuffle=True, seed=42
         )
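With this change the loader covers three cases: a local JSON/JSONL file, a Hub dataset streamed directly (narrowed with data_files when given), and a fallback that downloads a single file from a Hub dataset repo and reads it as JSON. A hedged sketch of that fallback on its own, outside the config plumbing (the repo and file names come from the config above; nothing else here is the repo's code):

from datasets import load_dataset
from huggingface_hub import hf_hub_download

# Download one file from a dataset repo, then load it as plain JSON.
fp = hf_hub_download(
    repo_id="anon8231489123/ShareGPT_Vicuna_unfiltered",
    repo_type="dataset",
    filename="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
)
ds = load_dataset("json", data_files=fp, streaming=True, split=None)
print(ds)  # iterable dataset dict with a "train" split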
src/axolotl/utils/models.py
@@ -7,11 +7,16 @@ import torch
 import transformers
 from transformers import (
     AutoModelForCausalLM,
-    LlamaForCausalLM,
-    LlamaTokenizer,
     AutoTokenizer,
     PreTrainedModel,
 )
+try:
+    from transformers import (
+        LlamaForCausalLM,
+        LlamaTokenizer,
+    )
+except:
+    logging.warning("This version of transformers does not support Llama. Consider upgrading.")

 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN

@@ -95,7 +100,7 @@ def load_model(
             else True,
         )
         load_in_8bit = False
-    elif is_llama_derived_model:
+    elif is_llama_derived_model and "LlamaForCausalLM" in globals():
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
@@ -130,7 +135,7 @@ def load_model(

     if not tokenizer:
         try:
-            if is_llama_derived_model:
+            if is_llama_derived_model and "LlamaTokenizer" in globals():
                 tokenizer = LlamaTokenizer.from_pretrained(model)
             else:
                 tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
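Wrapping the Llama imports in try/except means LlamaForCausalLM and LlamaTokenizer simply never enter the module namespace on an older transformers release, and the added globals() checks route those cases to the generic loaders instead of hitting a NameError. An isolated sketch of the pattern, using a helper name of my own (pick_model_cls) rather than anything from the repo:

import logging

from transformers import AutoModelForCausalLM

try:
    from transformers import LlamaForCausalLM
except ImportError:
    logging.warning("This version of transformers does not support Llama. Consider upgrading.")

def pick_model_cls(is_llama_derived_model: bool):
    # Use the Llama class only when the optional import actually succeeded;
    # otherwise fall back to the Auto class, which older releases still provide.
    if is_llama_derived_model and "LlamaForCausalLM" in globals():
        return LlamaForCausalLM
    return AutoModelForCausalLM

The diff uses a bare except, which swallows more than ImportError; the sketch narrows it to the usual idiom, but the effect on older transformers releases is the same.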