"""Module for plain input/output prompt pairs""" from typing import Generator, Tuple from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter class RawInputOutputStrategy(PromptTokenizingStrategy): """Prompt Strategy class for input/output pairs""" def __init__(self, *args, eos_token=None, **kwargs): super().__init__(*args, **kwargs) self.eos_token = eos_token if not eos_token: self.eos_token = self.tokenizer.eos_token def tokenize_prompt(self, prompt): # pylint: disable=duplicate-code input_ids = [] labels = [] for label, text in self.prompter.build_prompt(prompt["segments"]): tokenized_output = self.tokenizer( text, add_special_tokens=False, return_tensors=None )["input_ids"] input_ids += tokenized_output if label or self.train_on_inputs: labels += tokenized_output else: labels += [IGNORE_TOKEN_ID] * len(tokenized_output) tokenized_prompt = { "input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids), } return tokenized_prompt class RawInputOutputPrompter(Prompter): """prompter for raw i/o data""" def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]: for segment in source: yield segment["label"], segment["text"] def load(tokenizer, cfg): return RawInputOutputStrategy( RawInputOutputPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len, )