apply black formatting
- src/axolotl/prompt_strategies/__init__.py +1 -0
- src/axolotl/prompt_strategies/alpaca_chat.py +12 -3
- src/axolotl/prompt_strategies/alpaca_instruct.py +4 -1
- src/axolotl/prompt_strategies/creative_acr.py +18 -6
- src/axolotl/prompt_strategies/pygmalion.py +22 -12
- src/axolotl/prompt_tokenizers.py +53 -36
- src/axolotl/prompters.py +48 -15
- src/axolotl/utils/data.py +68 -26
- src/axolotl/utils/models.py +10 -3
- src/axolotl/utils/trainer.py +12 -6
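Every hunk below is a mechanical rewrite by black at its default 88-character line length; no behavior changes. The commit does not record the exact invocation, but the effect is easy to reproduce with black's Python API. A minimal sketch, using a pre-formatting import line reconstructed from the alpaca_chat.py hunk below:

import black

# One of the lines this commit reformats, as a single long line:
src = (
    "from axolotl.prompt_tokenizers import "
    "AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy\n"
)
print(black.format_str(src, mode=black.Mode()))
# from axolotl.prompt_tokenizers import (
#     AlpacaPromptTokenizingStrategy,
#     InstructionPromptTokenizingStrategy,
# )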
src/axolotl/prompt_strategies/__init__.py

@@ -1,5 +1,6 @@
 import importlib
 
+
 def load(strategy, tokenizer, cfg):
     try:
         load_fn = "load"
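For context, this load() is the dynamic strategy lookup that utils/data.py consults below (if ds_strategy := load(d.type, tokenizer, cfg)). Only the first lines of its body appear in the hunk; the following is a rough sketch of what an importlib-based loader of this shape does, with the module path and error handling assumed rather than taken from the diff:

import importlib


def load(strategy, tokenizer, cfg):
    # Sketch only: resolve "module" or "module:function" under axolotl.prompt_strategies
    # and call the named factory with (tokenizer, cfg); return None when nothing matches.
    try:
        load_fn = "load"
        if ":" in strategy:
            strategy, load_fn = strategy.split(":", 1)
        module = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
        return getattr(module, load_fn)(tokenizer, cfg)
    except ModuleNotFoundError:
        return None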
src/axolotl/prompt_strategies/alpaca_chat.py

@@ -1,10 +1,16 @@
-from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    InstructionPromptTokenizingStrategy,
+)
 from axolotl.prompters import AlpacaPrompter, PromptStyle
 
 
 def load(tokenizer, cfg):
     return AlpacaPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        AlpacaPrompter(PromptStyle.chat.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
 
 
@@ -19,5 +25,8 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 
 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        AlpacaPrompter(PromptStyle.chat.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
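Both factory functions now pass the prompter, tokenizer, cfg.train_on_inputs and cfg.sequence_len one argument per line. A minimal sketch of how such a factory is driven; the base model name and the SimpleNamespace stand-in for axolotl's config object are assumptions, not part of this commit:

from types import SimpleNamespace

from transformers import AutoTokenizer

from axolotl.prompt_strategies.alpaca_chat import load

cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)  # stand-in for the YAML-driven cfg
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # assumed base model

strategy = load(tokenizer, cfg)  # an AlpacaPromptTokenizingStrategy over the chat-style prompter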
src/axolotl/prompt_strategies/alpaca_instruct.py

@@ -4,5 +4,8 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle
 
 def load(tokenizer, cfg):
     return AlpacaPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.instruct), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        AlpacaPrompter(PromptStyle.instruct),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
src/axolotl/prompt_strategies/creative_acr.py

@@ -7,7 +7,9 @@ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
 class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         question = prompt["instruction"]
-        answer = prompt["revision"]  # don't use prompt[answer], that's data we don't want in the dataset
+        answer = prompt[
+            "revision"
+        ]  # don't use prompt[answer], that's data we don't want in the dataset
         return (
             question,
             "",

@@ -48,8 +50,12 @@ Answer: {answer}
 """
 
     def parse_instruction_fields(self, prompt) -> (str, str, str):
-        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
-        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
+        scores = yaml.dump(
+            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        critiques = yaml.dump(
+            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+        )
         evaluation = scores + critiques
         question = prompt["instruction"]
         answer = prompt["answer"]

@@ -76,13 +82,19 @@ Evaluation:
 """
 
     def parse_instruction_fields(self, prompt) -> (str, str, str):
-        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
-        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
+        scores = yaml.dump(
+            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        critiques = yaml.dump(
+            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+        )
         evaluation = scores + critiques
         question = prompt["instruction"]
         answer = prompt["answer"]
         return (
-            self.user_prompt.format(question=question, answer=answer, evaluation=evaluation),
+            self.user_prompt.format(
+                question=question, answer=answer, evaluation=evaluation
+            ),
             "",
             prompt["revision"],
         )
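The only changes here are black wrapping the prompt["revision"] lookup and the two yaml.dump(...) calls. As a reminder of what default_flow_style=False produces (block style rather than inline flow style), with a made-up scores dict standing in for a dataset row:

import yaml

scores = {"creativity": 8, "relevance": 7}  # hypothetical values; real rows carry the dataset's rubric
print(yaml.dump(scores, default_flow_style=False, Dumper=yaml.Dumper))
# creativity: 8
# relevance: 7
print(yaml.dump(scores, default_flow_style=True, Dumper=yaml.Dumper))
# {creativity: 8, relevance: 7}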
src/axolotl/prompt_strategies/pygmalion.py

@@ -30,20 +30,34 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
                 # this should include a bos token, no eos token, strip trailing "\n<START>"
                 if message.endswith("\n<START>"):
                     message = message[:-8]
-                res = self._tokenize(prefix + "Persona: " + message.strip(), add_eos_token=False, strip_bos_token=False)
+                res = self._tokenize(
+                    prefix + "Persona: " + message.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=False,
+                )
                 # everything from this is masked out from the labels
-                labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
             elif role == "human":
                 prefix = "<|user|>"
-                res = self._tokenize(prefix + " " + message.strip(), add_eos_token=False, strip_bos_token=True)
+                res = self._tokenize(
+                    prefix + " " + message.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=True,
+                )
                 # everything from this is masked out from the labels
-                labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
             elif role == "bot":
                 prefix = "<|model|>"
-                res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
+                res = self._tokenize(
+                    prefix + " " + message.strip(),
+                    add_eos_token=True,
+                    strip_bos_token=True,
+                )
                 # mask out the prefix token, rest is not masked out from labels
                 # make sure we create the labels first, otherwise we get incorrect lengths
-                labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):]
+                labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [
+                    *copy.deepcopy(res["input_ids"])
+                ][len(self.bot_prefix_token_ids) :]
             else:
                 logging.warning(f"unknown role in conversation: {role}")
                 res = defaultdict(lambda: [])

@@ -51,8 +65,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
             input_len = len(input_ids)
             result["input_ids"][current_len : current_len + input_len] = input_ids
             result["attention_mask"][current_len : current_len + input_len] = [
-                1 if x != self.tokenizer.pad_token_id else 0
-                for x in input_ids
+                1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
             ]
             result["labels"][current_len : current_len + input_len] = labels
             current_len += input_len

@@ -74,10 +87,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
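The reflowed labels expressions implement the masking described in the comments: system and human turns are ignored entirely in the loss, and for bot turns only the <|model|> prefix tokens are ignored. A toy illustration of the bot-turn arithmetic with made-up token ids (IGNORE_TOKEN_ID is assumed to be -100, the label value the loss skips):

import copy

IGNORE_TOKEN_ID = -100             # assumed; imported from axolotl.prompt_tokenizers in the real module
bot_prefix_token_ids = [32001]     # hypothetical ids for "<|model|>"
input_ids = [32001, 319, 3974, 2]  # prefix + reply tokens + eos, all hypothetical

labels = [IGNORE_TOKEN_ID] * len(bot_prefix_token_ids) + [
    *copy.deepcopy(input_ids)
][len(bot_prefix_token_ids) :]
print(labels)  # [-100, 319, 3974, 2] -> only the bot's reply contributes to the loss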
src/axolotl/prompt_tokenizers.py

@@ -59,10 +59,14 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = next(iter(self.prompter.build_prompt(
-                instruction,
-                input,
-            )))
+            user_prompt = next(
+                iter(
+                    self.prompter.build_prompt(
+                        instruction,
+                        input,
+                    )
+                )
+            )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing

@@ -73,11 +77,15 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt
 
     def _build_full_prompt(self, instruction, input, response):
-        return next(iter(self.prompter.build_prompt(
-            instruction,
-            input,
-            response,
-        )))
+        return next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                    response,
+                )
+            )
+        )
 
     def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         result = self.tokenizer(

@@ -95,10 +103,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
 
@@ -201,10 +206,14 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         )
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = next(iter(self.prompter.build_prompt(
-                instruction,
-                input,
-            )))
+            user_prompt = next(
+                iter(
+                    self.prompter.build_prompt(
+                        instruction,
+                        input,
+                    )
+                )
+            )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing

@@ -215,13 +224,17 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt
 
     def _build_full_prompt(self, instruction, input, output, reflection, corrected):
-        return next(iter(self.prompter.build_prompt(
-            instruction,
-            input,
-            output,
-            reflection,
-            corrected,
-        )))
+        return next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                    output,
+                    reflection,
+                    corrected,
+                )
+            )
+        )
 
     def _tokenize(self, prompt, add_eos_token=True):
         result = self.tokenizer(

@@ -265,21 +278,27 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         user_token = self._get_user_token()
         assistant_token = self._get_assistant_token()
         try:
-            for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
+            for i, part in enumerate(
+                self.prompter.build_prompt(prompt["conversations"])
+            ):
                 if isinstance(part, tuple):
                     if part[0] == "USER:":
                         part = part[0] + part[1] if not user_token else part[1]
                         # this is still the user query, we should
-                        res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
+                        res = self._tokenize(
+                            part.strip(), add_eos_token=False, strip_bos_token=True
+                        )
                         if user_token:
                             res["input_ids"] = [user_token, *res["input_ids"]]
                         # everything from this is masked out from the labels
-                        labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                     elif part[0] == "ASSISTANT:":
                         # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
                         part = part[0] + part[1] if not assistant_token else part[1]
                         # this should be the assistent response, should end with an eos token
-                        res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
+                        res = self._tokenize(
+                            part.strip(), add_eos_token=True, strip_bos_token=True
+                        )
                         if assistant_token:
                             res["input_ids"] = [assistant_token, *res["input_ids"]]
                         # not masked out from labels

@@ -288,15 +307,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         logging.warning("unhandled role: " + part[0])
                 else:
                     # this is only ever the first part, should include the bos token and the user query
-                    res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False)
+                    res = self._tokenize(
+                        part.strip(), add_eos_token=False, strip_bos_token=False
+                    )
                     # everything from this is masked out from the labels
-                    labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                 input_ids = res["input_ids"]
                 input_len = len(input_ids)
                 result["input_ids"][current_len : current_len + input_len] = input_ids
                 result["attention_mask"][current_len : current_len + input_len] = [
-                    1 if x != self.tokenizer.pad_token_id else 0
-                    for x in input_ids
+                    1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
                 ]
                 result["labels"][current_len : current_len + input_len] = labels
                 current_len += input_len

@@ -320,10 +340,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
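The repeated next(iter(self.prompter.build_prompt(...))) blocks take the first string out of a prompter's generator, since the prompters yield their rendered prompt rather than returning it. A stripped-down sketch of the pattern with a stand-in generator:

def build_prompt(instruction, input=None):
    # Stand-in for a prompter's build_prompt, which yields its rendered template.
    yield f"### Instruction:\n{instruction}\n\n### Response:\n"


user_prompt = next(iter(build_prompt("Name three colors.")))
print(user_prompt)
# iter() is a no-op on a generator object, but it keeps the call working for any iterable.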
src/axolotl/prompters.py

@@ -23,12 +23,22 @@ class AlpacaPrompter:
 
     def match_prompt_style(self):
         if self.prompt_style == PromptStyle.instruct.value:
-            self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-            self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
+            self.prompt_input = (
+                self.system_prompt
+                + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            )
+            self.prompt_no_input = (
+                self.system_no_input_prompt
+                + "### Instruction:\n{instruction}\n\n### Response:\n"
+            )
             self.response_split = "### Response:"
         if self.prompt_style == PromptStyle.chat.value:
-            self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
-            self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+            self.prompt_input = (
+                self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
+            )
+            self.prompt_no_input = (
+                self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+            )
             self.response_split = "ASSISTANT:"
 
     def build_prompt(

@@ -55,12 +65,15 @@ class UnpromptedPrompter(AlpacaPrompter):
     system_prompt = ""
     system_no_input_prompt = ""
 
+
 class JeopardyPrompter(AlpacaPrompter):
     prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
 
 
 class MultipleChoiceExplainPrompter(AlpacaPrompter):
-    system_prompt = "Choose the answer that best answers the question. Explain your reasoning."
+    system_prompt = (
+        "Choose the answer that best answers the question. Explain your reasoning."
+    )
 
 
 class MultipleChoiceConcisePrompter(AlpacaPrompter):

@@ -68,11 +81,15 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
 
 
 class SummarizeTLDRPrompter(AlpacaPrompter):
-    prompt_no_input = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
+    prompt_no_input = (
+        "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
+    )
 
 
 class CompletionPrompter(AlpacaPrompter):
-    def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
+    def build_prompt(
+        self, instruction: str, input=None, output=None
+    ) -> Generator[str, None, None]:
         yield instruction
 
     def get_response(self, output: str) -> str:

@@ -91,7 +108,9 @@ class ReflectAlpacaPrompter:
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
 
-    prompt_input = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+    prompt_input = (
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+    )
     prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
     agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
     response_split = "### Response:"

@@ -102,14 +121,26 @@ class ReflectAlpacaPrompter:
 
     def match_prompt_style(self):
         if self.prompt_style == PromptStyle.instruct.value:
-            self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-            self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
+            self.prompt_input = (
+                self.system_prompt
+                + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            )
+            self.prompt_no_input = (
+                self.system_no_input_prompt
+                + "### Instruction:\n{instruction}\n\n### Response:\n"
+            )
             self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
             self.response_split = "### Final Response:"
         if self.prompt_style == PromptStyle.chat.value:
-            self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
-            self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
-            self.agent_label = "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
+            self.prompt_input = (
+                self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
+            )
+            self.prompt_no_input = (
+                self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+            )
+            self.agent_label = (
+                "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
+            )
             self.response_split = "ASSISTANT:"
 
     def build_prompt(

@@ -167,7 +198,7 @@ class Conversation:
                 yield (role + ":", " " + message)
             else:
                 logging.warning("role with empty message: " + role)
-                yield (role + ":", )
+                yield (role + ":",)
 
     def copy(self):
         return Conversation(

@@ -199,7 +230,9 @@ conv_vicuna_v1_1 = Conversation(
 class ShareGPTPrompter:
     def __init__(self, prompt_style=None):
         if prompt_style != PromptStyle.chat.value:
-            raise Exception(f"unsupported prompt_style for ShareGPTPrompter({prompt_style})")
+            raise Exception(
+                f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
+            )
 
     # def match_prompt_style(self):
     #     if self.prompt_style == PromptStyle.chat.value:
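The wrapped template strings are plain str.format templates; match_prompt_style only decides whether the system prompt is prepended. A quick render of the chat-style prompt_input assembled above, with the long system prompt abbreviated for the example:

system_prompt = "Below is an instruction that describes a task...\n\n"  # abbreviated stand-in
prompt_input = system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"

print(prompt_input.format(instruction="Summarize the article.", input="<article text>"))
# Below is an instruction that describes a task...
#
# USER: Summarize the article.
# <article text>
# ASSISTANT: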
src/axolotl/utils/data.py

@@ -7,7 +7,8 @@ from datasets import (
     load_dataset,
     IterableDataset,
     Dataset,
-    concatenate_datasets, DatasetDict,
+    concatenate_datasets,
+    DatasetDict,
 )
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase

@@ -33,11 +34,14 @@ from axolotl.prompters import (
     JeopardyPrompter,
     CompletionPrompter,
     MultipleChoiceExplainPrompter,
-    SummarizeTLDRPrompter, MultipleChoiceConcisePrompter,
+    SummarizeTLDRPrompter,
+    MultipleChoiceConcisePrompter,
 )
 
 
-def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict:
+def load_tokenized_prepared_datasets(
+    tokenizer, cfg, default_dataset_prepared_path
+) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(

@@ -45,7 +49,8 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path
                 str(cfg.sequence_len)
                 + "@"
                 + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
-                + "|" + tokenizer_name
+                + "|"
+                + tokenizer_name
             ).encode("utf-8")
         ).hexdigest()
     )

@@ -57,7 +62,9 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path
     dataset = None
     try:
         if cfg.push_dataset_to_hub:
-            dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
+            dataset = load_dataset(
+                f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
+            )
             dataset = dataset["train"]
     except:
         pass

@@ -88,7 +95,12 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path
             )
         elif ds_from_hub:
             if d.data_files:
-                ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True)
+                ds = load_dataset(
+                    d.path,
+                    streaming=False,
+                    data_files=d.data_files,
+                    use_auth_token=True,
+                )
             else:
                 ds = load_dataset(d.path, streaming=False, use_auth_token=True)
         else:

@@ -100,49 +112,65 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path
                 raise Exception("unhandled dataset load")
             # support for using a subset of the data
             if d.shards:
-                ds = ds.shuffle(seed=42)["train"].shard(
-                    num_shards=cfg.shards, index=0
-                )
+                ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
             d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-            if (ds_strategy := load(d.type, tokenizer, cfg)):
+            if ds_strategy := load(d.type, tokenizer, cfg):
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
             elif d_base_type == "alpaca":
                 ds_strategy = AlpacaPromptTokenizingStrategy(
-                    AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    AlpacaPrompter(d_prompt_style),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
             elif d_base_type == "explainchoice":
                 ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                    MultipleChoiceExplainPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    MultipleChoiceExplainPrompter(d_prompt_style),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
             elif d_base_type == "concisechoice":
                 ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                    MultipleChoiceConcisePrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    MultipleChoiceConcisePrompter(d_prompt_style),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
             elif d_base_type == "summarizetldr":
                 ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
-                    SummarizeTLDRPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    SummarizeTLDRPrompter(d_prompt_style),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
             elif d_base_type == "jeopardy":
                 ds_strategy = JeopardyPromptTokenizingStrategy(
-                    JeopardyPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    JeopardyPrompter(d_prompt_style),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
             elif d_base_type == "oasst":
                 ds_strategy = OpenAssistantPromptTokenizingStrategy(
-                    AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    AlpacaPrompter(d_prompt_style),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)

@@ -166,7 +194,10 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path
                 datasets.append(ds_wrapper)
             elif d_base_type == "sharegpt":
                 ds_strategy = ShareGPTPromptTokenizingStrategy(
-                    ShareGPTPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                    ShareGPTPrompter(d_prompt_style),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)

@@ -196,12 +227,16 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path
                 logging.info(
                     f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                 )
-                dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
+                dataset.push_to_hub(
+                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+                )
 
     return dataset
 
 
-def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset):
+def load_prepare_datasets(
+    tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
+) -> (Dataset, Dataset):
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )

@@ -221,7 +256,8 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
                 + str(max_packed_sequence_len)
                 + seed
                 + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
-                + "|" + tokenizer_name
+                + "|"
+                + tokenizer_name
             ).encode("utf-8")
         ).hexdigest()
     )

@@ -237,7 +273,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
             logging.info(
                 f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
             )
-            dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
+            dataset = load_dataset(
+                f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
+            )
             dataset = dataset["train"]
         except:
             pass

@@ -254,7 +292,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
                 logging.info(
                     f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                 )
-                dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
+                dataset.push_to_hub(
+                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+                )
         else:
             dataset = load_tokenized_prepared_datasets(
                 tokenizer, cfg, default_dataset_prepared_path

@@ -279,9 +319,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
                     d
                     for d in dataset
                     if len(d["input_ids"]) < cfg.sequence_len
+                    and len(d["input_ids"]) > 0
+                    and len(d["input_ids"]) == len(d["attention_mask"])
+                    and len(d["input_ids"]) == len(d["labels"])
                 ]
             )

@@ -294,7 +334,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
             logging.info(
                 f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
             )
-            dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
+            dataset.push_to_hub(
+                f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+            )
         else:
            dataset = load_tokenized_prepared_datasets(
                tokenizer, cfg, default_dataset_prepared_path
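The reflowed hash expressions build the cache key for prepared datasets out of the sequence length, the sorted path:type dataset specs, and the tokenizer class name (now on its own continuation line). A standalone sketch of the same construction with made-up config values:

from hashlib import md5

sequence_len = 2048
dataset_specs = ["teknium/GPT4-LLM-Cleaned:alpaca"]  # hypothetical cfg.datasets entries
tokenizer_name = "LlamaTokenizer"

ds_hash = str(
    md5(
        (
            str(sequence_len)
            + "@"
            + "|".join(sorted(dataset_specs))
            + "|"
            + tokenizer_name
        ).encode("utf-8")
    ).hexdigest()
)
print(ds_hash)  # changes whenever the sequence length, the datasets, or the tokenizer class changes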
src/axolotl/utils/models.py

@@ -11,7 +11,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    AutoConfig, BitsAndBytesConfig,
+    AutoConfig,
+    BitsAndBytesConfig,
 )
 
 try:

@@ -244,7 +245,9 @@ def load_model(
         embeddings_len = math.ceil(len(tokenizer) / 32) * 32
         model.resize_token_embeddings(embeddings_len)
 
-    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
+    if (
+        (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
+    ) and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -265,7 +268,11 @@ def load_model(
                 m.scales = m.scales.half()
                 m.bias = m.bias.half()
 
-    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1 and cfg.load_4bit:
+    if (
+        torch.cuda.device_count() > 1
+        and int(os.getenv("WORLD_SIZE", "1")) > 1
+        and cfg.load_4bit
+    ):
         # llama is PROBABLY model parallelizable, but the default isn't that it is
         # so let's only set it for the 4bit, see
         # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
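The second rewrapped condition gates the 4-bit model-parallel path on a multi-GPU, multi-process launch. WORLD_SIZE is the process count exported by distributed launchers such as torchrun; a small check in the same spirit (the cfg.load_4bit flag is left out of the sketch):

import os

import torch

world_size = int(os.getenv("WORLD_SIZE", "1"))  # defaults to a single process
multi_gpu_multi_process = torch.cuda.device_count() > 1 and world_size > 1
print("would consider the 4-bit model-parallel branch:", multi_gpu_multi_process)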
src/axolotl/utils/trainer.py

@@ -17,10 +17,12 @@ from axolotl.utils.callbacks import SavePeftModelCallback
 
 
 class OneCycleLRSchedulerTrainer(Trainer):
-    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+    ):
         optimizer = self.optimizer if optimizer is None else optimizer
         num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
         num_training_steps = num_training_steps
         pct_start = num_warmup_steps / num_training_steps
 
         self.lr_scheduler = OneCycleLR(

@@ -203,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter == 'lora': # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter == "lora":  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {

@@ -214,7 +216,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8
 
-    trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer
+    trainer_cls = (
+        OneCycleLRSchedulerTrainer
+        if cfg.lr_scheduler == "one_cycle" and cfg.fsdp
+        else transformers.Trainer
+    )
     trainer = trainer_cls(
         model=model,
         train_dataset=train_dataset,