winglian commited on
Commit
97d3776
·
unverified ·
1 Parent(s): 2844eb2

split completion text to sequence_len (#616)

Browse files
src/axolotl/datasets.py CHANGED
@@ -38,10 +38,15 @@ class TokenizedPromptDataset(Dataset):
38
  def process(self, dataset):
39
  features = dataset.features.keys()
40
  num_proc = min(64, os.cpu_count())
 
 
 
 
41
  return dataset.map(
42
  self.prompt_tokenizer.tokenize_prompt,
43
  num_proc=num_proc,
44
  remove_columns=features,
 
45
  )
46
 
47
 
 
38
  def process(self, dataset):
39
  features = dataset.features.keys()
40
  num_proc = min(64, os.cpu_count())
41
+ map_kwargs = {}
42
+ if self.prompt_tokenizer.supports_batched:
43
+ map_kwargs["batched"] = True
44
+ map_kwargs["batch_size"] = 100
45
  return dataset.map(
46
  self.prompt_tokenizer.tokenize_prompt,
47
  num_proc=num_proc,
48
  remove_columns=features,
49
+ **map_kwargs,
50
  )
51
 
52
 
src/axolotl/prompt_strategies/completion.py CHANGED
@@ -1,10 +1,81 @@
1
  """
2
  Basic completion text
3
  """
4
- from typing import Any, Dict, Optional
 
5
 
6
- from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
7
- from axolotl.prompters import CompletionPrompter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
@@ -13,6 +84,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
13
  tokenizer,
14
  cfg.train_on_inputs,
15
  cfg.sequence_len,
 
16
  )
17
  if ds_cfg and "field" in ds_cfg:
18
  strat.field = ds_cfg["field"]
 
1
  """
2
  Basic completion text
3
  """
4
+ from collections import defaultdict
5
+ from typing import Any, Dict, Generator, Optional, Tuple
6
 
7
+ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
8
+
9
+
10
+ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
11
+ """
12
+ Tokenizing strategy for Completion prompts.
13
+ """
14
+
15
+ _field: str = "text"
16
+
17
+ def __init__(self, *args, max_length=None, **kwargs):
18
+ super().__init__(*args, **kwargs)
19
+ if max_length is not None:
20
+ self.max_length = max_length
21
+
22
+ @property
23
+ def supports_batched(self):
24
+ return True
25
+
26
+ @property
27
+ def field(self) -> str:
28
+ return self._field
29
+
30
+ @field.setter
31
+ def field(self, new_field: str):
32
+ self._field = new_field
33
+
34
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
35
+ return (
36
+ prompt[self.field],
37
+ "",
38
+ "",
39
+ )
40
+
41
+ def tokenize_prompt(self, prompt):
42
+ res = defaultdict(lambda: [])
43
+ feature_names = list(prompt.keys())
44
+ for row in zip(*prompt.values()):
45
+ prompt_row = dict(zip(feature_names, row))
46
+ (
47
+ instruction,
48
+ _,
49
+ _,
50
+ ) = self.parse_instruction_fields(prompt_row)
51
+
52
+ full_prompt = self._build_full_prompt(instruction, None, None)
53
+ tokenized_full_prompt = self._tokenize(full_prompt)
54
+
55
+ for key, val in tokenized_full_prompt.items():
56
+ for i in range(0, len(val), self.sequence_len):
57
+ res[key].append(val[i : i + self.sequence_len])
58
+
59
+ return dict(res)
60
+
61
+ def _build_full_prompt(
62
+ self, instruction, input, response
63
+ ): # pylint: disable=redefined-builtin
64
+ return next(iter(self.prompter.build_prompt(instruction, input, response)))
65
+
66
+
67
+ class CompletionPrompter:
68
+ """
69
+ Prompter for completion
70
+ """
71
+
72
+ def build_prompt(
73
+ self,
74
+ instruction: str,
75
+ input=None, # pylint: disable=redefined-builtin, unused-argument
76
+ output=None, # pylint: disable=unused-argument
77
+ ) -> Generator[str, None, None]:
78
+ yield instruction
79
 
80
 
81
  def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
 
84
  tokenizer,
85
  cfg.train_on_inputs,
86
  cfg.sequence_len,
87
+ max_length=cfg.sequence_len * 64,
88
  )
89
  if ds_cfg and "field" in ds_cfg:
90
  strat.field = ds_cfg["field"]
src/axolotl/prompt_tokenizers.py CHANGED
@@ -41,11 +41,16 @@ class PromptTokenizingStrategy(abc.ABC):
41
  self.tokenizer: PreTrainedTokenizer = tokenizer
42
  self.train_on_inputs = train_on_inputs
43
  self.sequence_len = sequence_len
 
44
 
45
  @abc.abstractmethod
46
  def tokenize_prompt(self, prompt):
47
  pass
48
 
 
 
 
 
49
  @functools.lru_cache(maxsize=128)
50
  def _get_user_token(self):
51
  try:
@@ -77,7 +82,7 @@ class PromptTokenizingStrategy(abc.ABC):
77
  result = self.tokenizer(
78
  prompt,
79
  truncation=True,
80
- max_length=self.sequence_len,
81
  padding=False,
82
  return_tensors=None,
83
  )
@@ -86,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
86
  if (
87
  len(result["input_ids"]) > 0
88
  and result["input_ids"][-1] != self.tokenizer.eos_token_id
89
- and len(result["input_ids"]) < self.sequence_len
90
  and add_eos_token
91
  ):
92
  result["input_ids"].append(self.tokenizer.eos_token_id)
@@ -247,46 +252,6 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
247
  )
248
 
249
 
250
- class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
251
- """
252
- Tokenizing strategy for Completion prompts.
253
- """
254
-
255
- _field: str = "text"
256
-
257
- @property
258
- def field(self) -> str:
259
- return self._field
260
-
261
- @field.setter
262
- def field(self, new_field: str):
263
- self._field = new_field
264
-
265
- def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
266
- return (
267
- prompt[self.field],
268
- "",
269
- "",
270
- )
271
-
272
- def tokenize_prompt(self, prompt):
273
- (
274
- instruction,
275
- _,
276
- _,
277
- ) = self.parse_instruction_fields(prompt)
278
-
279
- full_prompt = self._build_full_prompt(instruction, None, None)
280
- tokenized_full_prompt = self._tokenize(full_prompt)
281
-
282
- return tokenized_full_prompt
283
-
284
- def _build_full_prompt(
285
- self, instruction, input, response
286
- ): # pylint: disable=redefined-builtin
287
- return next(iter(self.prompter.build_prompt(instruction, input, response)))
288
-
289
-
290
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
291
  """
292
  Tokenizing strategy for Reflection prompts.
 
41
  self.tokenizer: PreTrainedTokenizer = tokenizer
42
  self.train_on_inputs = train_on_inputs
43
  self.sequence_len = sequence_len
44
+ self.max_length = sequence_len
45
 
46
  @abc.abstractmethod
47
  def tokenize_prompt(self, prompt):
48
  pass
49
 
50
+ @property
51
+ def supports_batched(self):
52
+ return False
53
+
54
  @functools.lru_cache(maxsize=128)
55
  def _get_user_token(self):
56
  try:
 
82
  result = self.tokenizer(
83
  prompt,
84
  truncation=True,
85
+ max_length=self.max_length,
86
  padding=False,
87
  return_tensors=None,
88
  )
 
91
  if (
92
  len(result["input_ids"]) > 0
93
  and result["input_ids"][-1] != self.tokenizer.eos_token_id
94
+ and len(result["input_ids"]) < self.max_length
95
  and add_eos_token
96
  ):
97
  result["input_ids"].append(self.tokenizer.eos_token_id)
 
252
  )
253
 
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
256
  """
257
  Tokenizing strategy for Reflection prompts.
src/axolotl/prompters.py CHANGED
@@ -135,20 +135,6 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
135
  self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
136
 
137
 
138
- class CompletionPrompter:
139
- """
140
- Prompter for completion
141
- """
142
-
143
- def build_prompt(
144
- self,
145
- instruction: str,
146
- input=None, # pylint: disable=redefined-builtin, unused-argument
147
- output=None, # pylint: disable=unused-argument
148
- ) -> Generator[str, None, None]:
149
- yield instruction
150
-
151
-
152
  class GPTeacherPrompter(AlpacaPrompter):
153
  """
154
  Prompter for GPTeacher
 
135
  self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  class GPTeacherPrompter(AlpacaPrompter):
139
  """
140
  Prompter for GPTeacher