Nanobit committed on
Commit
25eeeeb
·
1 Parent(s): cfcc549

Fix sharegpt prompt

Browse files
src/axolotl/prompt_tokenizers.py CHANGED
@@ -371,15 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
371
  ]
372
  # not masked out from labels
373
  labels = copy.deepcopy(res["input_ids"])
 
 
 
 
 
 
 
 
374
  else:
375
  logging.warning(f"unhandled role: {part[0]}")
376
- else:
377
- # this is only ever the first part, should include the bos token and the user query
378
- res = self._tokenize(
379
- part.strip(), add_eos_token=False, strip_bos_token=False
380
- )
381
- # everything from this is masked out from the labels
382
- labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
383
 
384
  # pylint: disable=duplicate-code
385
  result, current_len = parse_tokenized_to_result(
 
371
  ]
372
  # not masked out from labels
373
  labels = copy.deepcopy(res["input_ids"])
374
+ elif part[0] == "SYSTEM:":
375
+ part = part[1] # Ignore the system role from preamble
376
+ # this is only ever the first part, should include the bos token and the user query
377
+ res = self._tokenize(
378
+ part.strip(), add_eos_token=False, strip_bos_token=False
379
+ )
380
+ # everything from this is masked out from the labels
381
+ labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
382
  else:
383
  logging.warning(f"unhandled role: {part[0]}")
 
 
 
 
 
 
 
384
 
385
  # pylint: disable=duplicate-code
386
  result, current_len = parse_tokenized_to_result(
src/axolotl/prompters.py CHANGED
@@ -3,7 +3,7 @@
3
  import dataclasses
4
  import logging
5
  from enum import Enum, auto
6
- from typing import Generator, List, Optional, Union
7
 
8
  IGNORE_TOKEN_ID = -100
9
 
@@ -235,16 +235,16 @@ class Conversation:
235
  sep: str = "###"
236
  sep2: Optional[str] = None
237
 
238
- def get_prompt(self) -> Generator[str, None, None]:
239
  # seps = [self.sep, self.sep2]
240
  preamble = self.system + self.sep
241
- yield preamble
242
  for _, (role, message) in enumerate(self.messages):
243
  if message:
244
- yield role + ":" + " " + message
245
  else:
246
  logging.warning(f"role with empty message: {role}")
247
- yield role + ":"
248
 
249
  def copy(self):
250
  return Conversation(
 
3
  import dataclasses
4
  import logging
5
  from enum import Enum, auto
6
+ from typing import Generator, List, Optional, Tuple, Union
7
 
8
  IGNORE_TOKEN_ID = -100
9
 
 
235
  sep: str = "###"
236
  sep2: Optional[str] = None
237
 
238
+ def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
239
  # seps = [self.sep, self.sep2]
240
  preamble = self.system + self.sep
241
+ yield ("SYSTEM:", preamble)
242
  for _, (role, message) in enumerate(self.messages):
243
  if message:
244
+ yield (role + ":", " " + message)
245
  else:
246
  logging.warning(f"role with empty message: {role}")
247
+ yield (role + ":", "")
248
 
249
  def copy(self):
250
  return Conversation(