winglian commited on
Commit
4ea9a66
·
1 Parent(s): 1d5ab84

tokenization fixes

Browse files
src/axolotl/prompt_strategies/alpaca_chat.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
2
+ from axolotl.prompters import AlpacaPrompter, PromptStyle
3
+
4
+
5
+ def load(tokenizer, cfg):
6
+ return AlpacaPromptTokenizingStrategy(
7
+ AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len
8
+ )
src/axolotl/prompt_tokenizers.py CHANGED
@@ -38,14 +38,14 @@ class PromptTokenizingStrategy(abc.ABC):
38
  @functools.cache
39
  def _get_user_token(self):
40
  id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
41
- if type(id_or_ids, (int,)):
42
  return id_or_ids
43
  return False
44
 
45
  @functools.cache
46
  def _get_assistant_token(self):
47
  id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
48
- if type(id_or_ids, (int,)):
49
  return id_or_ids
50
  return False
51
 
@@ -272,15 +272,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
272
  # this is still the user query, we should
273
  res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
274
  if user_token:
275
- res = [user_token, *res]
276
  # everything from this is masked out from the labels
277
  labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
278
  elif part[0] == "ASSISTANT:":
 
279
  part = part[0] + part[1] if not assistant_token else part[1]
280
  # this should be the assistent response, should end with an eos token
281
  res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
282
  if assistant_token:
283
- res = [assistant_token, *res]
284
  # not masked out from labels
285
  labels = copy.deepcopy(res["input_ids"])
286
  else:
 
38
  @functools.cache
39
  def _get_user_token(self):
40
  id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
41
+ if isinstance(id_or_ids, (int,)):
42
  return id_or_ids
43
  return False
44
 
45
  @functools.cache
46
  def _get_assistant_token(self):
47
  id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
48
+ if isinstance(id_or_ids, (int,)):
49
  return id_or_ids
50
  return False
51
 
 
272
  # this is still the user query, we should
273
  res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True)
274
  if user_token:
275
+ res["input_ids"] = [user_token, *res["input_ids"]]
276
  # everything from this is masked out from the labels
277
  labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"])
278
  elif part[0] == "ASSISTANT:":
279
+ # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
280
  part = part[0] + part[1] if not assistant_token else part[1]
281
  # this should be the assistent response, should end with an eos token
282
  res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True)
283
  if assistant_token:
284
+ res["input_ids"] = [assistant_token, *res["input_ids"]]
285
  # not masked out from labels
286
  labels = copy.deepcopy(res["input_ids"])
287
  else:
src/axolotl/utils/data.py CHANGED
@@ -12,6 +12,7 @@ from datasets import (
12
  from huggingface_hub import hf_hub_download
13
 
14
  from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 
15
  from axolotl.prompt_tokenizers import (
16
  AlpacaPromptTokenizingStrategy,
17
  GPTeacherPromptTokenizingStrategy,
@@ -94,10 +95,13 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
94
  if not ds:
95
  raise Exception("unhandled dataset load")
96
  d_type = d.type
97
- d_type_split = d.type.split(":")
98
  d_base_type = d_type_split[0]
99
  d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
100
- if d_base_type == "alpaca":
 
 
 
101
  ds_strategy = AlpacaPromptTokenizingStrategy(
102
  AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
103
  )
 
12
  from huggingface_hub import hf_hub_download
13
 
14
  from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
15
+ from axolotl.prompt_strategies import load
16
  from axolotl.prompt_tokenizers import (
17
  AlpacaPromptTokenizingStrategy,
18
  GPTeacherPromptTokenizingStrategy,
 
95
  if not ds:
96
  raise Exception("unhandled dataset load")
97
  d_type = d.type
98
+ d_type_split = d_type.split(":")
99
  d_base_type = d_type_split[0]
100
  d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
101
+ if (ds_strategy := load(d.type, tokenizer, cfg)):
102
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
103
+ datasets.append(ds_wrapper)
104
+ elif d_base_type == "alpaca":
105
  ds_strategy = AlpacaPromptTokenizingStrategy(
106
  AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len
107
  )
src/axolotl/utils/models.py CHANGED
@@ -220,7 +220,7 @@ def load_model(
220
  for k, v in cfg.special_tokens.items():
221
  tokenizer.add_special_tokens({k: v})
222
  if cfg.tokens:
223
- tokenizer.add_tokens(cfg.tokens)
224
 
225
  embeddings_len = math.ceil(len(tokenizer) / 32) * 32
226
  model.resize_token_embeddings(embeddings_len)
 
220
  for k, v in cfg.special_tokens.items():
221
  tokenizer.add_special_tokens({k: v})
222
  if cfg.tokens:
223
+ tokenizer.add_tokens(list(cfg.tokens))
224
 
225
  embeddings_len = math.ceil(len(tokenizer) / 32) * 32
226
  model.resize_token_embeddings(embeddings_len)