split completion text to sequence_len (#616)
Browse files
src/axolotl/datasets.py
CHANGED
@@ -38,10 +38,15 @@ class TokenizedPromptDataset(Dataset):
|
|
38 |
def process(self, dataset):
|
39 |
features = dataset.features.keys()
|
40 |
num_proc = min(64, os.cpu_count())
|
|
|
|
|
|
|
|
|
41 |
return dataset.map(
|
42 |
self.prompt_tokenizer.tokenize_prompt,
|
43 |
num_proc=num_proc,
|
44 |
remove_columns=features,
|
|
|
45 |
)
|
46 |
|
47 |
|
|
|
38 |
def process(self, dataset):
    """Tokenize every example in *dataset* with this dataset's prompt tokenizer.

    Uses batched mapping when the tokenizing strategy supports it (which lets
    one input example expand into several tokenized rows, e.g. when long
    completion text is split into sequence_len chunks), and removes the
    original text columns so only tokenized features remain.
    """
    features = dataset.features.keys()
    # os.cpu_count() may return None (per the stdlib docs); fall back to 1
    # so min() doesn't raise TypeError.
    num_proc = min(64, os.cpu_count() or 1)
    map_kwargs = {}
    if self.prompt_tokenizer.supports_batched:
        map_kwargs["batched"] = True
        map_kwargs["batch_size"] = 100
    return dataset.map(
        self.prompt_tokenizer.tokenize_prompt,
        num_proc=num_proc,
        remove_columns=features,
        **map_kwargs,
    )
|
51 |
|
52 |
|
src/axolotl/prompt_strategies/completion.py
CHANGED
@@ -1,10 +1,81 @@
|
|
1 |
"""
|
2 |
Basic completion text
|
3 |
"""
|
4 |
-
from
|
|
|
5 |
|
6 |
-
from axolotl.prompt_tokenizers import
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
@@ -13,6 +84,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
13 |
tokenizer,
|
14 |
cfg.train_on_inputs,
|
15 |
cfg.sequence_len,
|
|
|
16 |
)
|
17 |
if ds_cfg and "field" in ds_cfg:
|
18 |
strat.field = ds_cfg["field"]
|
|
|
1 |
"""
|
2 |
Basic completion text
|
3 |
"""
|
4 |
+
from collections import defaultdict
|
5 |
+
from typing import Any, Dict, Generator, Optional, Tuple
|
6 |
|
7 |
+
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
8 |
+
|
9 |
+
|
10 |
+
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.

    Operates in batched mode: ``tokenize_prompt`` receives a column-oriented
    batch (dict of column name -> list of values) and returns a dict of lists
    where each tokenized example is split into chunks of at most
    ``sequence_len`` tokens, so one long document yields multiple rows.
    """

    # Name of the dataset column holding the completion text.
    _field: str = "text"

    def __init__(self, *args, max_length=None, **kwargs):
        """Initialize the strategy.

        ``max_length`` optionally overrides the truncation length set by the
        base class (callers pass a multiple of ``sequence_len`` so long text
        survives tokenization and can then be chunked).
        """
        super().__init__(*args, **kwargs)
        if max_length is not None:
            self.max_length = max_length

    @property
    def supports_batched(self):
        # Opts in to dataset.map(batched=True) in TokenizedPromptDataset.process.
        return True

    @property
    def field(self) -> str:
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        # Completion data has no separate input/response parts — only raw text.
        return (
            prompt[self.field],
            "",
            "",
        )

    def tokenize_prompt(self, prompt):
        """Tokenize a batch of examples, splitting each tokenized sequence
        into ``sequence_len``-sized chunks (the final chunk may be shorter).
        """
        # defaultdict(list) is the idiomatic form of defaultdict(lambda: []).
        res = defaultdict(list)
        feature_names = list(prompt.keys())
        # Re-assemble each row from the column-oriented batch.
        for row in zip(*prompt.values()):
            prompt_row = dict(zip(feature_names, row))
            (
                instruction,
                _,
                _,
            ) = self.parse_instruction_fields(prompt_row)

            full_prompt = self._build_full_prompt(instruction, None, None)
            tokenized_full_prompt = self._tokenize(full_prompt)

            # Split every tokenized feature (input_ids, attention_mask, ...)
            # into sequence_len-sized windows.
            for key, val in tokenized_full_prompt.items():
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])

        return dict(res)

    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        # The prompter yields a single string for completion prompts.
        return next(iter(self.prompter.build_prompt(instruction, input, response)))
65 |
+
|
66 |
+
|
67 |
+
class CompletionPrompter:
    """Prompter for completion-style data: the raw text is the prompt."""

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin, unused-argument
        output=None,  # pylint: disable=unused-argument
    ) -> Generator[str, None, None]:
        # No template to apply — emit the instruction text unchanged.
        yield from (instruction,)
|
79 |
|
80 |
|
81 |
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
|
84 |
tokenizer,
|
85 |
cfg.train_on_inputs,
|
86 |
cfg.sequence_len,
|
87 |
+
max_length=cfg.sequence_len * 64,
|
88 |
)
|
89 |
if ds_cfg and "field" in ds_cfg:
|
90 |
strat.field = ds_cfg["field"]
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -41,11 +41,16 @@ class PromptTokenizingStrategy(abc.ABC):
|
|
41 |
self.tokenizer: PreTrainedTokenizer = tokenizer
|
42 |
self.train_on_inputs = train_on_inputs
|
43 |
self.sequence_len = sequence_len
|
|
|
44 |
|
45 |
@abc.abstractmethod
|
46 |
def tokenize_prompt(self, prompt):
|
47 |
pass
|
48 |
|
|
|
|
|
|
|
|
|
49 |
@functools.lru_cache(maxsize=128)
|
50 |
def _get_user_token(self):
|
51 |
try:
|
@@ -77,7 +82,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|
77 |
result = self.tokenizer(
|
78 |
prompt,
|
79 |
truncation=True,
|
80 |
-
max_length=self.
|
81 |
padding=False,
|
82 |
return_tensors=None,
|
83 |
)
|
@@ -86,7 +91,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|
86 |
if (
|
87 |
len(result["input_ids"]) > 0
|
88 |
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
89 |
-
and len(result["input_ids"]) < self.
|
90 |
and add_eos_token
|
91 |
):
|
92 |
result["input_ids"].append(self.tokenizer.eos_token_id)
|
@@ -247,46 +252,6 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|
247 |
)
|
248 |
|
249 |
|
250 |
-
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
251 |
-
"""
|
252 |
-
Tokenizing strategy for Completion prompts.
|
253 |
-
"""
|
254 |
-
|
255 |
-
_field: str = "text"
|
256 |
-
|
257 |
-
@property
|
258 |
-
def field(self) -> str:
|
259 |
-
return self._field
|
260 |
-
|
261 |
-
@field.setter
|
262 |
-
def field(self, new_field: str):
|
263 |
-
self._field = new_field
|
264 |
-
|
265 |
-
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
266 |
-
return (
|
267 |
-
prompt[self.field],
|
268 |
-
"",
|
269 |
-
"",
|
270 |
-
)
|
271 |
-
|
272 |
-
def tokenize_prompt(self, prompt):
|
273 |
-
(
|
274 |
-
instruction,
|
275 |
-
_,
|
276 |
-
_,
|
277 |
-
) = self.parse_instruction_fields(prompt)
|
278 |
-
|
279 |
-
full_prompt = self._build_full_prompt(instruction, None, None)
|
280 |
-
tokenized_full_prompt = self._tokenize(full_prompt)
|
281 |
-
|
282 |
-
return tokenized_full_prompt
|
283 |
-
|
284 |
-
def _build_full_prompt(
|
285 |
-
self, instruction, input, response
|
286 |
-
): # pylint: disable=redefined-builtin
|
287 |
-
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
288 |
-
|
289 |
-
|
290 |
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
291 |
"""
|
292 |
Tokenizing strategy for Reflection prompts.
|
|
|
41 |
self.tokenizer: PreTrainedTokenizer = tokenizer
|
42 |
self.train_on_inputs = train_on_inputs
|
43 |
self.sequence_len = sequence_len
|
44 |
+
self.max_length = sequence_len
|
45 |
|
46 |
@abc.abstractmethod
|
47 |
def tokenize_prompt(self, prompt):
|
48 |
pass
|
49 |
|
50 |
+
@property
def supports_batched(self):
    # Default: strategies tokenize one prompt at a time. Subclasses that can
    # consume a column-oriented batch (dict of lists) — e.g. the completion
    # strategy — override this to return True so dataset mapping uses
    # batched=True.
    return False
|
53 |
+
|
54 |
@functools.lru_cache(maxsize=128)
|
55 |
def _get_user_token(self):
|
56 |
try:
|
|
|
82 |
result = self.tokenizer(
|
83 |
prompt,
|
84 |
truncation=True,
|
85 |
+
max_length=self.max_length,
|
86 |
padding=False,
|
87 |
return_tensors=None,
|
88 |
)
|
|
|
91 |
if (
|
92 |
len(result["input_ids"]) > 0
|
93 |
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
94 |
+
and len(result["input_ids"]) < self.max_length
|
95 |
and add_eos_token
|
96 |
):
|
97 |
result["input_ids"].append(self.tokenizer.eos_token_id)
|
|
|
252 |
)
|
253 |
|
254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
256 |
"""
|
257 |
Tokenizing strategy for Reflection prompts.
|
src/axolotl/prompters.py
CHANGED
@@ -135,20 +135,6 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
|
|
135 |
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
136 |
|
137 |
|
138 |
-
class CompletionPrompter:
|
139 |
-
"""
|
140 |
-
Prompter for completion
|
141 |
-
"""
|
142 |
-
|
143 |
-
def build_prompt(
|
144 |
-
self,
|
145 |
-
instruction: str,
|
146 |
-
input=None, # pylint: disable=redefined-builtin, unused-argument
|
147 |
-
output=None, # pylint: disable=unused-argument
|
148 |
-
) -> Generator[str, None, None]:
|
149 |
-
yield instruction
|
150 |
-
|
151 |
-
|
152 |
class GPTeacherPrompter(AlpacaPrompter):
|
153 |
"""
|
154 |
Prompter for GPTeacher
|
|
|
135 |
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
136 |
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
class GPTeacherPrompter(AlpacaPrompter):
|
139 |
"""
|
140 |
Prompter for GPTeacher
|