|
""" |
|
Basic completion-text prompt tokenizing strategy (raw text, chunked to sequence_len).
|
""" |
|
from collections import defaultdict |
|
from typing import Any, Dict, Generator, Optional, Tuple |
|
|
|
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy |
|
|
|
|
|
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for raw-text completion prompts.

    Each row's text (read from the ``field`` column, ``"text"`` by default) is
    passed through the prompter, tokenized, and then split into consecutive
    chunks of at most ``sequence_len`` tokens. Every chunk becomes its own
    example in the batched output.
    """

    _field: str = "text"

    def __init__(self, *args, max_length=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Only override max_length when the caller passes one explicitly;
        # otherwise keep whatever the base initializer set (if anything).
        if max_length is not None:
            self.max_length = max_length

    @property
    def supports_batched(self):
        """Whether ``tokenize_prompt`` expects a column-oriented batch (it does)."""
        return True

    @property
    def field(self) -> str:
        """Name of the dataset column that holds the completion text."""
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Return ``(text, "", "")``: completion rows have no input or response."""
        return (
            prompt[self.field],
            "",
            "",
        )

    def tokenize_prompt(self, prompt):
        """
        Tokenize a batch of rows and split each one into ``sequence_len`` chunks.

        Args:
            prompt: column-oriented batch, a mapping of feature name to a
                sequence of values (one value per row).

        Returns:
            dict mapping each key produced by ``self._tokenize`` to a flat list
            of token chunks, in row order. The last chunk of a row may be
            shorter than ``sequence_len``. Rows that tokenize to nothing
            contribute no chunks.

        Raises:
            ValueError: if ``sequence_len`` is ``None`` or not positive.
                Without this check, a negative step would make ``range``
                empty and silently drop every row.
        """
        chunk_len = self.sequence_len
        if chunk_len is None or chunk_len <= 0:
            raise ValueError(
                f"sequence_len must be a positive integer, got {chunk_len!r}"
            )

        res = defaultdict(list)
        feature_names = list(prompt.keys())
        # Transpose the column-oriented batch into one dict per row.
        for row in zip(*prompt.values()):
            prompt_row = dict(zip(feature_names, row))
            instruction, _, _ = self.parse_instruction_fields(prompt_row)

            full_prompt = self._build_full_prompt(instruction, None, None)
            tokenized_full_prompt = self._tokenize(full_prompt)

            # Split long documents into consecutive fixed-size windows so that
            # no example exceeds the training sequence length.
            for key, val in tokenized_full_prompt.items():
                for i in range(0, len(val), chunk_len):
                    res[key].append(val[i : i + chunk_len])

        return dict(res)

    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        """Return the first prompt string yielded by ``self.prompter.build_prompt``."""
        # Parameter names are kept for interface compatibility with existing
        # callers, even though ``input`` shadows the builtin.
        return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
|
|
|
|
class CompletionPrompter:
    """
    Pass-through prompter for raw completion text.

    The instruction text is emitted verbatim as the one and only prompt.
    ``input`` and ``output`` are accepted but unused here.
    """

    def build_prompt(
        self,
        instruction: str,
        input: Optional[str] = None,  # pylint: disable=redefined-builtin
        output: Optional[str] = None,
    ) -> Generator[str, None, None]:
        """Yield ``instruction`` unchanged, exactly once."""
        yield from (instruction,)
|
|
|
|
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build a completion tokenizing strategy from the training config.

    Args:
        tokenizer: tokenizer forwarded to the strategy.
        cfg: training config; ``train_on_inputs`` and ``sequence_len`` are read.
        ds_cfg: optional dataset config. A ``"field"`` entry, when present,
            selects the column holding the completion text.

    Returns:
        A configured ``CompletionPromptTokenizingStrategy``.
    """
    # 64x the training sequence length; presumably caps raw tokenization in the
    # base strategy before tokenize_prompt splits it into sequence_len chunks.
    max_len = cfg.sequence_len * 64

    strategy = CompletionPromptTokenizingStrategy(
        CompletionPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        max_length=max_len,
    )

    # With no dataset-level override, keep the strategy's default column ("text").
    if not ds_cfg or "field" not in ds_cfg:
        return strategy

    strategy.field = ds_cfg["field"]
    return strategy
|
|