winglian's picture
plain input/output prompt strategy w/o chat templates (#1346)
4d09b42 unverified
raw
history blame
1.7 kB
"""Module for plain input/output prompt pairs"""
from typing import Generator, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
class RawInputOutputStrategy(PromptTokenizingStrategy):
    """Tokenizing strategy for plain input/output segments.

    Each sample carries a ``"segments"`` entry; every segment is tokenized on
    its own, without special tokens, and the results are concatenated.
    """

    def __init__(self, *args, eos_token=None, **kwargs):
        """Forward all arguments to the parent strategy and record the EOS token.

        Args:
            *args: Positional arguments passed through to the parent class.
            eos_token: Optional EOS token override. When falsy, the
                tokenizer's own ``eos_token`` is stored instead.
            **kwargs: Keyword arguments passed through to the parent class.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): self.tokenizer is presumably set by the parent
        # __init__ -- it is not assigned anywhere in this class.
        self.eos_token = eos_token or self.tokenizer.eos_token

    def tokenize_prompt(self, prompt):
        """Tokenize the segments of one sample into a flat training example.

        Args:
            prompt: Mapping with a ``"segments"`` entry that is handed to the
                prompter's ``build_prompt``.

        Returns:
            A dict with ``"input_ids"`` (concatenated token ids of all
            segments), ``"labels"`` (the same ids for segments whose label is
            truthy or when ``train_on_inputs`` is set, ``IGNORE_TOKEN_ID``
            otherwise) and ``"attention_mask"`` (a 1 for every token).
        """
        # pylint: disable=duplicate-code
        token_ids = []
        label_ids = []
        segments = self.prompter.build_prompt(prompt["segments"])
        for is_trainable, segment_text in segments:
            encoding = self.tokenizer(
                segment_text, add_special_tokens=False, return_tensors=None
            )
            segment_ids = encoding["input_ids"]
            token_ids.extend(segment_ids)
            if not (is_trainable or self.train_on_inputs):
                # Masked segment: keep the length aligned with the token ids
                # but fill it with the ignore marker.
                label_ids.extend([IGNORE_TOKEN_ID] * len(segment_ids))
            else:
                label_ids.extend(segment_ids)

        return {
            "input_ids": token_ids,
            "labels": label_ids,
            "attention_mask": [1] * len(token_ids),
        }
class RawInputOutputPrompter(Prompter):
    """Prompter that passes raw input/output segments through unchanged."""

    def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
        """Yield a ``(label, text)`` pair for every segment in ``source``.

        Args:
            source: Iterable of mappings, each read via its ``"label"`` and
                ``"text"`` keys.

        Yields:
            The segment's ``"label"`` value followed by its ``"text"`` value.
        """
        for piece in source:
            segment_label, segment_text = piece["label"], piece["text"]
            yield segment_label, segment_text
def load(tokenizer, cfg):
    """Build the raw input/output tokenizing strategy.

    Args:
        tokenizer: Tokenizer handed to the strategy as its second argument.
        cfg: Configuration object; its ``train_on_inputs`` and
            ``sequence_len`` attributes are forwarded as the third and
            fourth positional arguments.

    Returns:
        A ``RawInputOutputStrategy`` wired to a fresh
        ``RawInputOutputPrompter``.
    """
    prompter = RawInputOutputPrompter()
    # Arguments stay positional: the parent constructor's parameter names are
    # not visible from this module.
    strategy_args = (prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len)
    return RawInputOutputStrategy(*strategy_args)