winglian committed
Commit ce34d64
1 Parent(s): ce694e2

apply black formatting
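Every hunk in this commit follows the same mechanical pattern: the Black formatter wraps any statement that exceeds its default 88-character line limit, splitting call arguments one per line with a trailing comma, parenthesizing long string concatenations, and normalizing spacing inside brackets. A minimal before/after sketch of that behaviour (an illustrative snippet, not code taken from this repository):

    # before: a single call longer than Black's line-length limit
    strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(prompt_style), tokenizer, train_on_inputs, sequence_len)

    # after running `black`: one argument per line, trailing comma, closing paren on its own line
    strategy = AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(prompt_style),
        tokenizer,
        train_on_inputs,
        sequence_len,
    )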
src/axolotl/prompt_strategies/__init__.py CHANGED
@@ -1,5 +1,6 @@
 import importlib
 
+
 def load(strategy, tokenizer, cfg):
     try:
         load_fn = "load"
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -1,10 +1,16 @@
-from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    InstructionPromptTokenizingStrategy,
+)
 from axolotl.prompters import AlpacaPrompter, PromptStyle
 
 
 def load(tokenizer, cfg):
     return AlpacaPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        AlpacaPrompter(PromptStyle.chat.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
 
 
@@ -19,5 +25,8 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 
 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.chat.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        AlpacaPrompter(PromptStyle.chat.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
src/axolotl/prompt_strategies/alpaca_instruct.py CHANGED
@@ -4,5 +4,8 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle
 
 def load(tokenizer, cfg):
     return AlpacaPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.instruct), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        AlpacaPrompter(PromptStyle.instruct),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
src/axolotl/prompt_strategies/creative_acr.py CHANGED
@@ -7,7 +7,9 @@ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
 class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         question = prompt["instruction"]
-        answer = prompt["revision"]  # don't use prompt[answer], that's data we don't want in the dataset
+        answer = prompt[
+            "revision"
+        ]  # don't use prompt[answer], that's data we don't want in the dataset
         return (
             question,
             "",
@@ -48,8 +50,12 @@ Answer: {answer}
 """
 
     def parse_instruction_fields(self, prompt) -> (str, str, str):
-        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
-        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
+        scores = yaml.dump(
+            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        critiques = yaml.dump(
+            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+        )
         evaluation = scores + critiques
         question = prompt["instruction"]
         answer = prompt["answer"]
@@ -76,13 +82,19 @@ Evaluation:
 """
 
     def parse_instruction_fields(self, prompt) -> (str, str, str):
-        scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
-        critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
+        scores = yaml.dump(
+            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+        )
+        critiques = yaml.dump(
+            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+        )
         evaluation = scores + critiques
         question = prompt["instruction"]
         answer = prompt["answer"]
         return (
-            self.user_prompt.format(question=question, answer=answer, evaluation=evaluation),
+            self.user_prompt.format(
+                question=question, answer=answer, evaluation=evaluation
+            ),
             "",
             prompt["revision"],
         )
src/axolotl/prompt_strategies/pygmalion.py CHANGED
@@ -30,20 +30,34 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
                 # this should include a bos token, no eos token, strip trailing "\n<START>"
                 if message.endswith("\n<START>"):
                     message = message[:-8]
-                res = self._tokenize(prefix + "Persona: " + message.strip(), add_eos_token=False, strip_bos_token=False)
+                res = self._tokenize(
+                    prefix + "Persona: " + message.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=False,
+                )
                 # everything from this is masked out from the labels
-                labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
             elif role == "human":
                 prefix = "<|user|>"
-                res = self._tokenize(prefix + " " + message.strip(), add_eos_token=False, strip_bos_token=True)
+                res = self._tokenize(
+                    prefix + " " + message.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=True,
+                )
                 # everything from this is masked out from the labels
-                labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
+                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
             elif role == "bot":
                 prefix = "<|model|>"
-                res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
+                res = self._tokenize(
+                    prefix + " " + message.strip(),
+                    add_eos_token=True,
+                    strip_bos_token=True,
+                )
                 # mask out the prefix token, rest is not masked out from labels
                 # make sure we create the labels first, otherwise we get incorrect lengths
-                labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):]
+                labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [
+                    *copy.deepcopy(res["input_ids"])
+                ][len(self.bot_prefix_token_ids) :]
             else:
                 logging.warning(f"unknown role in conversation: {role}")
                 res = defaultdict(lambda: [])
@@ -51,8 +65,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
             input_len = len(input_ids)
             result["input_ids"][current_len : current_len + input_len] = input_ids
             result["attention_mask"][current_len : current_len + input_len] = [
-                1 if x != self.tokenizer.pad_token_id else 0
-                for x in input_ids
+                1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
             ]
             result["labels"][current_len : current_len + input_len] = labels
             current_len += input_len
@@ -74,10 +87,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
 
src/axolotl/prompt_tokenizers.py CHANGED
@@ -59,10 +59,14 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = next(iter(self.prompter.build_prompt(
-                instruction,
-                input,
-            )))
+            user_prompt = next(
+                iter(
+                    self.prompter.build_prompt(
+                        instruction,
+                        input,
+                    )
+                )
+            )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
@@ -73,11 +77,15 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt
 
     def _build_full_prompt(self, instruction, input, response):
-        return next(iter(self.prompter.build_prompt(
-            instruction,
-            input,
-            response,
-        )))
+        return next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                    response,
+                )
+            )
+        )
 
     def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         result = self.tokenizer(
@@ -95,10 +103,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
 
@@ -201,10 +206,14 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         )
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
-            user_prompt = next(iter(self.prompter.build_prompt(
-                instruction,
-                input,
-            )))
+            user_prompt = next(
+                iter(
+                    self.prompter.build_prompt(
+                        instruction,
+                        input,
+                    )
+                )
+            )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
             # TODO this could be sped up using numpy array slicing
@@ -215,13 +224,17 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt
 
     def _build_full_prompt(self, instruction, input, output, reflection, corrected):
-        return next(iter(self.prompter.build_prompt(
-            instruction,
-            input,
-            output,
-            reflection,
-            corrected,
-        )))
+        return next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                    output,
+                    reflection,
+                    corrected,
+                )
+            )
+        )
 
     def _tokenize(self, prompt, add_eos_token=True):
         result = self.tokenizer(
@@ -265,21 +278,27 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         user_token = self._get_user_token()
         assistant_token = self._get_assistant_token()
         try:
-            for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
+            for i, part in enumerate(
+                self.prompter.build_prompt(prompt["conversations"])
+            ):
                 if isinstance(part, tuple):
                     if part[0] == "USER:":
                         part = part[0] + part[1] if not user_token else part[1]
                         # this is still the user query, we should
-                        res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
+                        res = self._tokenize(
+                            part.strip(), add_eos_token=False, strip_bos_token=True
+                        )
                         if user_token:
                             res["input_ids"] = [user_token, *res["input_ids"]]
                         # everything from this is masked out from the labels
-                        labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
+                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                     elif part[0] == "ASSISTANT:":
                         # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
                         part = part[0] + part[1] if not assistant_token else part[1]
                         # this should be the assistent response, should end with an eos token
-                        res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
+                        res = self._tokenize(
+                            part.strip(), add_eos_token=True, strip_bos_token=True
+                        )
                         if assistant_token:
                             res["input_ids"] = [assistant_token, *res["input_ids"]]
                         # not masked out from labels
@@ -288,15 +307,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         logging.warning("unhandled role: " + part[0])
                 else:
                     # this is only ever the first part, should include the bos token and the user query
-                    res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False)
+                    res = self._tokenize(
+                        part.strip(), add_eos_token=False, strip_bos_token=False
+                    )
                     # everything from this is masked out from the labels
-                    labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                 input_ids = res["input_ids"]
                 input_len = len(input_ids)
                 result["input_ids"][current_len : current_len + input_len] = input_ids
                 result["attention_mask"][current_len : current_len + input_len] = [
-                    1 if x != self.tokenizer.pad_token_id else 0
-                    for x in input_ids
+                    1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
                 ]
                 result["labels"][current_len : current_len + input_len] = labels
                 current_len += input_len
@@ -320,10 +340,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
             result["input_ids"].append(self.tokenizer.eos_token_id)
             result["attention_mask"].append(1)
 
-        if (
-            result["input_ids"][0] == self.tokenizer.bos_token_id
-            and strip_bos_token
-        ):
+        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
             result["input_ids"] = result["input_ids"][1:]
             result["attention_mask"] = result["attention_mask"][1:]
 
src/axolotl/prompters.py CHANGED
@@ -23,12 +23,22 @@ class AlpacaPrompter:
 
     def match_prompt_style(self):
         if self.prompt_style == PromptStyle.instruct.value:
-            self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-            self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
+            self.prompt_input = (
+                self.system_prompt
+                + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            )
+            self.prompt_no_input = (
+                self.system_no_input_prompt
+                + "### Instruction:\n{instruction}\n\n### Response:\n"
+            )
             self.response_split = "### Response:"
         if self.prompt_style == PromptStyle.chat.value:
-            self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
-            self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+            self.prompt_input = (
+                self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
+            )
+            self.prompt_no_input = (
+                self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+            )
             self.response_split = "ASSISTANT:"
 
     def build_prompt(
@@ -55,12 +65,15 @@ class UnpromptedPrompter(AlpacaPrompter):
     system_prompt = ""
     system_no_input_prompt = ""
 
+
 class JeopardyPrompter(AlpacaPrompter):
     prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
 
 
 class MultipleChoiceExplainPrompter(AlpacaPrompter):
-    system_prompt = "Choose the answer that best answers the question. Explain your reasoning."
+    system_prompt = (
+        "Choose the answer that best answers the question. Explain your reasoning."
+    )
 
 
 class MultipleChoiceConcisePrompter(AlpacaPrompter):
@@ -68,11 +81,15 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
 
 
 class SummarizeTLDRPrompter(AlpacaPrompter):
-    prompt_no_input = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
+    prompt_no_input = (
+        "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
+    )
 
 
 class CompletionPrompter(AlpacaPrompter):
-    def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
+    def build_prompt(
+        self, instruction: str, input=None, output=None
+    ) -> Generator[str, None, None]:
        yield instruction
 
     def get_response(self, output: str) -> str:
@@ -91,7 +108,9 @@ class ReflectAlpacaPrompter:
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
 
-    prompt_input = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+    prompt_input = (
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+    )
     prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n"
     agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
     response_split = "### Response:"
@@ -102,14 +121,26 @@ class ReflectAlpacaPrompter:
 
     def match_prompt_style(self):
         if self.prompt_style == PromptStyle.instruct.value:
-            self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-            self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n"
+            self.prompt_input = (
+                self.system_prompt
+                + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            )
+            self.prompt_no_input = (
+                self.system_no_input_prompt
+                + "### Instruction:\n{instruction}\n\n### Response:\n"
+            )
             self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
             self.response_split = "### Final Response:"
         if self.prompt_style == PromptStyle.chat.value:
-            self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
-            self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
-            self.agent_label = "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
+            self.prompt_input = (
+                self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
+            )
+            self.prompt_no_input = (
+                self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
+            )
+            self.agent_label = (
+                "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:"
+            )
             self.response_split = "ASSISTANT:"
 
     def build_prompt(
@@ -167,7 +198,7 @@ class Conversation:
                 yield (role + ":", " " + message)
             else:
                 logging.warning("role with empty message: " + role)
-                yield (role + ":", )
+                yield (role + ":",)
 
     def copy(self):
         return Conversation(
@@ -199,7 +230,9 @@ conv_vicuna_v1_1 = Conversation(
 class ShareGPTPrompter:
     def __init__(self, prompt_style=None):
         if prompt_style != PromptStyle.chat.value:
-            raise Exception(f"unsupported prompt_style for ShareGPTPrompter({prompt_style})")
+            raise Exception(
+                f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
+            )
 
     # def match_prompt_style(self):
     #     if self.prompt_style == PromptStyle.chat.value:
src/axolotl/utils/data.py CHANGED
@@ -7,7 +7,8 @@ from datasets import (
     load_dataset,
     IterableDataset,
     Dataset,
-    concatenate_datasets, DatasetDict,
+    concatenate_datasets,
+    DatasetDict,
 )
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -33,11 +34,14 @@ from axolotl.prompters import (
     JeopardyPrompter,
     CompletionPrompter,
     MultipleChoiceExplainPrompter,
-    SummarizeTLDRPrompter, MultipleChoiceConcisePrompter,
+    SummarizeTLDRPrompter,
+    MultipleChoiceConcisePrompter,
 )
 
 
-def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict:
+def load_tokenized_prepared_datasets(
+    tokenizer, cfg, default_dataset_prepared_path
+) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -45,7 +49,8 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
                 str(cfg.sequence_len)
                 + "@"
                 + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
-                + "|" + tokenizer_name
+                + "|"
+                + tokenizer_name
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -57,7 +62,9 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
     dataset = None
     try:
         if cfg.push_dataset_to_hub:
-            dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
+            dataset = load_dataset(
+                f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
+            )
             dataset = dataset["train"]
     except:
         pass
@@ -88,7 +95,12 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
             )
         elif ds_from_hub:
             if d.data_files:
-                ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True)
+                ds = load_dataset(
+                    d.path,
+                    streaming=False,
+                    data_files=d.data_files,
+                    use_auth_token=True,
+                )
             else:
                 ds = load_dataset(d.path, streaming=False, use_auth_token=True)
         else:
@@ -100,49 +112,65 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
             raise Exception("unhandled dataset load")
         # support for using a subset of the data
         if d.shards:
-            ds = ds.shuffle(seed=42)["train"].shard(
-                num_shards=cfg.shards, index=0
-            )
+            ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
         d_type = d.type
         d_type_split = d_type.split(":")
        d_base_type = d_type_split[0]
        d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-        if (ds_strategy := load(d.type, tokenizer, cfg)):
+        if ds_strategy := load(d.type, tokenizer, cfg):
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
            datasets.append(ds_wrapper)
        elif d_base_type == "alpaca":
            ds_strategy = AlpacaPromptTokenizingStrategy(
-                AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                AlpacaPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
            )
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
            datasets.append(ds_wrapper)
        elif d_base_type == "explainchoice":
            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                MultipleChoiceExplainPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                MultipleChoiceExplainPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
            )
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
            datasets.append(ds_wrapper)
        elif d_base_type == "concisechoice":
            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
-                MultipleChoiceConcisePrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                MultipleChoiceConcisePrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
            )
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
            datasets.append(ds_wrapper)
        elif d_base_type == "summarizetldr":
            ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
-                SummarizeTLDRPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                SummarizeTLDRPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
            )
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
            datasets.append(ds_wrapper)
        elif d_base_type == "jeopardy":
            ds_strategy = JeopardyPromptTokenizingStrategy(
-                JeopardyPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                JeopardyPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
            )
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
            datasets.append(ds_wrapper)
        elif d_base_type == "oasst":
            ds_strategy = OpenAssistantPromptTokenizingStrategy(
-                AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                AlpacaPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
            )
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
            datasets.append(ds_wrapper)
@@ -166,7 +194,10 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
            datasets.append(ds_wrapper)
        elif d_base_type == "sharegpt":
            ds_strategy = ShareGPTPromptTokenizingStrategy(
-                ShareGPTPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                ShareGPTPrompter(d_prompt_style),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
            )
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
            datasets.append(ds_wrapper)
@@ -196,12 +227,16 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
         logging.info(
             f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
         )
-        dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
+        dataset.push_to_hub(
+            f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+        )
 
     return dataset
 
 
-def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset):
+def load_prepare_datasets(
+    tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
+) -> (Dataset, Dataset):
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
@@ -221,7 +256,8 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
                 + str(max_packed_sequence_len)
                 + seed
                 + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
-                + "|" + tokenizer_name
+                + "|"
+                + tokenizer_name
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -237,7 +273,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
         logging.info(
             f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
         )
-        dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True)
+        dataset = load_dataset(
+            f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
+        )
         dataset = dataset["train"]
     except:
         pass
@@ -254,7 +292,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
             logging.info(
                 f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
             )
-            dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
+            dataset.push_to_hub(
+                f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+            )
         else:
             dataset = load_tokenized_prepared_datasets(
                 tokenizer, cfg, default_dataset_prepared_path
@@ -279,9 +319,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
                 d
                 for d in dataset
                 if len(d["input_ids"]) < cfg.sequence_len
-                and len(d["input_ids"]) > 0
-                and len(d["input_ids"]) == len(d["attention_mask"])
-                and len(d["input_ids"]) == len(d["labels"])
+                and len(d["input_ids"]) > 0
+                and len(d["input_ids"]) == len(d["attention_mask"])
+                and len(d["input_ids"]) == len(d["labels"])
             ]
         )
 
@@ -294,7 +334,9 @@ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_datas
            logging.info(
                f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
            )
-            dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
+            dataset.push_to_hub(
+                f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+            )
        else:
            dataset = load_tokenized_prepared_datasets(
                tokenizer, cfg, default_dataset_prepared_path
src/axolotl/utils/models.py CHANGED
@@ -11,7 +11,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    AutoConfig, BitsAndBytesConfig,
+    AutoConfig,
+    BitsAndBytesConfig,
 )
 
 try:
@@ -244,7 +245,9 @@ def load_model(
         embeddings_len = math.ceil(len(tokenizer) / 32) * 32
         model.resize_token_embeddings(embeddings_len)
 
-    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
+    if (
+        (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
+    ) and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -265,7 +268,11 @@ def load_model(
                 m.scales = m.scales.half()
                 m.bias = m.bias.half()
 
-    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1 and cfg.load_4bit:
+    if (
+        torch.cuda.device_count() > 1
+        and int(os.getenv("WORLD_SIZE", "1")) > 1
+        and cfg.load_4bit
+    ):
         # llama is PROBABLY model parallelizable, but the default isn't that it is
         # so let's only set it for the 4bit, see
         # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
src/axolotl/utils/trainer.py CHANGED
@@ -17,10 +17,12 @@ from axolotl.utils.callbacks import SavePeftModelCallback
 
 
 class OneCycleLRSchedulerTrainer(Trainer):
-    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
-        optimizer=self.optimizer if optimizer is None else optimizer
-        num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
-        num_training_steps=num_training_steps
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+    ):
+        optimizer = self.optimizer if optimizer is None else optimizer
+        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
+        num_training_steps = num_training_steps
         pct_start = num_warmup_steps / num_training_steps
 
         self.lr_scheduler = OneCycleLR(
@@ -203,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter == 'lora':  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter == "lora":  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
@@ -214,7 +216,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8
 
-    trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer
+    trainer_cls = (
+        OneCycleLRSchedulerTrainer
+        if cfg.lr_scheduler == "one_cycle" and cfg.fsdp
+        else transformers.Trainer
+    )
     trainer = trainer_cls(
         model=model,
         train_dataset=train_dataset,