|
"""Module containing PromptTokenizingStrategy and Prompter classes""" |
|
|
|
import abc |
|
import copy |
|
import logging |
|
from typing import Dict, List, Tuple, Union |
|
|
|
from fastchat.conversation import Conversation |
|
from transformers import BatchEncoding, PreTrainedTokenizer |
|
|
|
from axolotl.monkeypatch.fastchat_conversation_turns import ( |
|
add_get_turns_to_conversation, |
|
) |
|
from axolotl.prompters import IGNORE_TOKEN_ID |
|
|
|
# Logger shared by this module; registered under the "axolotl" name.
LOG = logging.getLogger("axolotl")

# Label value written over prompt tokens by the instruction and reflection
# strategies below when train_on_inputs is disabled.
# NOTE(review): presumably equal to IGNORE_TOKEN_ID imported from
# axolotl.prompters (used by the ShareGPT strategy) -- confirm.
IGNORE_INDEX = -100
# Special-token strings; not referenced anywhere else in the visible part of
# this module.
LLAMA_DEFAULT_PAD_TOKEN = "<pad>"
LLAMA_DEFAULT_EOS_TOKEN = "</s>"
LLAMA_DEFAULT_BOS_TOKEN = "<s>"
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"

# Runs at import time. Presumably patches fastchat's Conversation class with a
# get_turns method (inferred from the helper's name) -- confirm in
# axolotl.monkeypatch.fastchat_conversation_turns.
add_get_turns_to_conversation()
|
|
|
|
|
class InvalidDataException(Exception):
    """Signals that a piece of input data is invalid and cannot be processed."""
|
|
|
|
|
class PromptTokenizingStrategy(abc.ABC):
    """
    Abstract base class for tokenizing strategies.

    A strategy pairs a prompter (which renders text) with a tokenizer and
    exposes ``tokenize_prompt`` to turn one dataset row into token ids.
    """

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs: bool = False,
        sequence_len: int = 2048,
    ):
        """
        Store the collaborators and limits used by ``_tokenize``.

        :param prompter: object used by subclasses to render prompt text.
        :param tokenizer: callable tokenizer; must expose ``eos_token_id``
            and ``bos_token_id``.
        :param train_on_inputs: when False, subclasses mask prompt tokens in
            the labels.
        :param sequence_len: truncation limit passed to the tokenizer.
        """
        self.prompter = prompter
        self.tokenizer: PreTrainedTokenizer = tokenizer
        self.train_on_inputs = train_on_inputs
        # Both attributes hold the same limit; ``_tokenize`` reads max_length,
        # while some subclasses read sequence_len.
        self.sequence_len = sequence_len
        self.max_length = sequence_len

    @abc.abstractmethod
    def tokenize_prompt(self, prompt):
        """Tokenize a single dataset row; must be implemented by subclasses."""

    @property
    def supports_batched(self):
        """Flag reported by the strategy; the base implementation returns False."""
        return False

    def _tokenize(
        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
    ) -> BatchEncoding:
        """
        Tokenize ``prompt`` and attach a ``labels`` copy of the input ids.

        :param prompt: text to tokenize.
        :param add_eos_token: append the tokenizer's EOS id when the sequence
            does not already end with it and is shorter than ``max_length``.
        :param strip_bos_token: drop the leading token when it is the
            tokenizer's BOS id.
        :return: encoding with ``input_ids``, ``attention_mask`` and
            ``labels``. For empty text or an empty tokenizer result all three
            are empty lists.
        """
        # The empty result carries "labels" as well, so every return path has
        # the same keys and callers can index result["labels"] safely.
        empty = BatchEncoding(
            data={"input_ids": [], "attention_mask": [], "labels": []}
        )
        if not prompt:
            LOG.warning("Empty text requested for tokenization.")
            return empty

        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors=None,
        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
            return empty

        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.max_length
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result
|
|
|
|
|
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.

    Subclasses implement ``parse_instruction_fields`` to pull the
    instruction, input and response text out of a dataset row.
    """

    def parse_instruction_fields(
        self, prompt
    ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
        """Return the text fields of ``prompt``; must be overridden."""
        raise NotImplementedError

    def tokenize_prompt(self, prompt):
        """
        Tokenize the rendered user prompt and the response, then join them.

        The user prompt is rendered by the prompter from the instruction and
        input, and tokenized without EOS. The response is tokenized with its
        BOS stripped and EOS appended. When ``train_on_inputs`` is False the
        user-prompt positions in ``labels`` are set to ``IGNORE_INDEX``.

        :param prompt: dataset row understood by ``parse_instruction_fields``.
        :return: the encoding with ``input_ids``, ``attention_mask`` and
            ``labels`` covering prompt followed by response.
        """
        (
            instruction,
            input_text,
            response,
        ) = self.parse_instruction_fields(prompt)
        user_prompt = next(
            iter(
                self.prompter.build_prompt(
                    instruction,
                    input_text,
                )
            )
        )
        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
        if not self.train_on_inputs:
            user_prompt_len = len(tokenized_prompt["input_ids"])
            tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
        elif "labels" not in tokenized_prompt:
            # Guard against an encoding that carries no "labels" entry (e.g.
            # an empty tokenization); the += below would raise KeyError.
            tokenized_prompt["labels"] = list(tokenized_prompt["input_ids"])

        tokenized_res_prompt = self._tokenize(
            response, strip_bos_token=True, add_eos_token=True
        )
        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
        # Response tokens are always trained on, so their labels are the ids.
        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]

        return tokenized_prompt

    def _build_full_prompt(
        self, instruction, input, response  # pylint: disable=redefined-builtin
    ):
        """Return the first prompt the prompter renders from all three fields."""
        return next(
            iter(
                self.prompter.build_prompt(
                    instruction,
                    input,
                    response,
                )
            )
        )
|
|
|
|
|
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Alpaca-format rows.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Read ``instruction``, the optional ``input`` and ``output`` from a row.

        A row without an ``input`` key yields an empty string for that slot.
        """
        instruction = prompt["instruction"]
        context = ""
        if "input" in prompt:
            context = prompt["input"]
        response = prompt["output"]
        return instruction, context, response
|
|
|
|
|
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Alpaca multiple-choice rows.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Read the question, its choices and the answer text from a row.

        Choices are rendered one per line as quoted bullet items. The answer
        is ``solution`` when that key exists, otherwise ``explanation``.
        """
        question = prompt["question"]
        choice_lines = [f'- "{choice}"' for choice in prompt["choices"]]
        answer_key = "solution" if "solution" in prompt else "explanation"
        return question, "\n".join(choice_lines), prompt[answer_key]
|
|
|
|
|
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Jeopardy rows.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Read ``question`` and ``category`` and build the response text.

        The response is the ``answer`` field prefixed with "what is ".
        """
        question = prompt["question"]
        category = prompt["category"]
        response = "what is " + prompt["answer"]
        return question, category, response
|
|
|
|
|
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for OpenAssistant rows.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Read ``INSTRUCTION`` and ``RESPONSE`` from a row; the input slot is "".
        """
        instruction = prompt["INSTRUCTION"]
        response = prompt["RESPONSE"]
        return instruction, "", response
|
|
|
|
|
class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for SummarizeTLDR rows.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Use ``article`` as the instruction and ``summary`` as the response.

        The input slot is always an empty string.
        """
        article = prompt["article"]
        summary = prompt["summary"]
        return article, "", summary
|
|
|
|
|
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for GPTeacher rows.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Read ``instruction``, the optional ``input`` and ``response`` from a row.

        A row without an ``input`` key yields an empty string for that slot.
        """
        instruction = prompt["instruction"]
        context = ""
        if "input" in prompt:
            context = prompt["input"]
        response = prompt["response"]
        return instruction, context, response
|
|
|
|
|
class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for NomicGPT4All rows.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Use ``prompt`` as the instruction and ``response`` as the response.

        The input slot is always an empty string.
        """
        instruction = prompt["prompt"]
        response = prompt["response"]
        return instruction, "", response
|
|
|
|
|
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Reflection prompts.

    Subclasses implement ``parse_instruction_fields`` to pull the
    instruction, input, output, reflection and corrected text out of a row.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
        """Return the five text fields of ``prompt``; must be overridden."""
        raise NotImplementedError

    def tokenize_prompt(self, prompt):
        """
        Tokenize the full reflection prompt rendered by the prompter.

        When ``train_on_inputs`` is False, the user prompt (instruction and
        input only) is tokenized separately to measure its length, and that
        many leading label positions are replaced with ``IGNORE_INDEX``.

        :param prompt: dataset row understood by ``parse_instruction_fields``.
        :return: encoding with ``input_ids``, ``attention_mask`` and ``labels``.
        """
        (
            instruction,
            input_text,
            output,
            reflection,
            corrected,
        ) = self.parse_instruction_fields(prompt)
        full_prompt = self._build_full_prompt(
            instruction, input_text, output, reflection, corrected
        )
        tokenized_full_prompt = self._tokenize(full_prompt)
        if not self.train_on_inputs:
            user_prompt = next(
                iter(
                    self.prompter.build_prompt(
                        instruction,
                        input_text,
                    )
                )
            )
            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                IGNORE_INDEX
            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]

        return tokenized_full_prompt

    def _build_full_prompt(
        self, instruction, input, output, reflection, corrected  # pylint: disable=redefined-builtin
    ):
        """Return the first prompt the prompter renders from all five fields."""
        return next(
            iter(
                self.prompter.build_prompt(
                    instruction,
                    input,
                    output,
                    reflection,
                    corrected,
                )
            )
        )

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
        """
        Tokenize ``prompt``, truncated to ``sequence_len``, and attach labels.

        :param prompt: text to tokenize.
        :param add_eos_token: append the tokenizer's EOS id when the sequence
            is non-empty, does not already end with it and is shorter than
            ``sequence_len``.
        :param strip_bos_token: drop the leading token when it is the
            tokenizer's BOS id.
        :return: the tokenizer result with ``labels`` set to a copy of
            ``input_ids``.
        """
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        # Check for emptiness first: indexing [-1] on an empty tokenization
        # would raise IndexError.
        if (
            add_eos_token
            and input_ids
            and input_ids[-1] != self.tokenizer.eos_token_id
            and len(input_ids) < self.sequence_len
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # Honour strip_bos_token instead of silently ignoring it.
        if (
            strip_bos_token
            and result["input_ids"]
            and result["input_ids"][0] == self.tokenizer.bos_token_id
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result
|
|
|
|
|
class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Alpaca Reflection rows.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
        """
        Read the five reflection fields from a row.

        Returns ``instruction``, the optional ``input`` ("" when the key is
        absent), ``output``, ``reflection`` and ``corrected``, in that order.
        """
        instruction = prompt["instruction"]
        context = ""
        if "input" in prompt:
            context = prompt["input"]
        output = prompt["output"]
        reflection = prompt["reflection"]
        corrected = prompt["corrected"]
        return instruction, context, output, reflection, corrected
|
|
|
|
|
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for ShareGPT-style multi-turn conversations.
    """

    def get_conversation_thread(self, prompt):
        """Return the turns stored under the row's ``conversations`` key."""
        return prompt["conversations"]

    def _context_labels(self, token_ids):
        """
        Build labels for a turn that is not an assistant reply.

        Returns a copy of ``token_ids`` when training on inputs, otherwise a
        list of ``IGNORE_TOKEN_ID`` of the same length.
        """
        if self.train_on_inputs:
            return copy.deepcopy(token_ids)
        return [IGNORE_TOKEN_ID] * len(token_ids)

    def tokenize_prompt(self, prompt):
        """
        Tokenize every turn the prompter yields and concatenate the results.

        The prompter is expected to yield ``(role, content)`` tuples; anything
        else is logged and skipped. User turns and role-less turns are masked
        unless ``train_on_inputs`` is set. Assistant turns are trained on,
        except for the role prefix, which is masked unless ``train_on_inputs``
        is set.

        :param prompt: dataset row holding the conversation.
        :return: dict with ``input_ids``, ``attention_mask`` and ``labels``.
        :raises InvalidDataException: when a KeyError, AssertionError or
            IndexError occurs while building or tokenizing the turns.
        """
        result, current_len = tokenize_prompt_default()
        conversation: Conversation = self.prompter._conversation.copy()

        # For vicuna_v1.1 conversations a row may carry its own role names;
        # each pair maps a template role name to the row's replacement.
        remap_pairs = None
        if (
            conversation.name == "vicuna_v1.1"
            and "roles" in prompt
            and len(prompt["roles"]) >= 2
        ):
            remap_pairs = (
                (conversation.roles[0], prompt["roles"][0]),
                (conversation.roles[1], prompt["roles"][1]),
            )

        try:
            turns = self.prompter.build_prompt(self.get_conversation_thread(prompt))
            for part in turns:
                if not isinstance(part, tuple):
                    LOG.warning(f"expected tuple, got {part}")
                    continue

                user, assistant = conversation.roles
                role, content = part

                if user in role:
                    if remap_pairs:
                        role = role.replace(remap_pairs[0][0], remap_pairs[0][1])
                    turn = role + content
                    if not content.strip():
                        LOG.warning(f"user turn has empty text: {prompt}")
                    res = self._tokenize(
                        turn,
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    labels = self._context_labels(res["input_ids"])
                elif assistant in role:
                    if remap_pairs:
                        role = role.replace(remap_pairs[1][0], remap_pairs[1][1])
                    turn = role + content
                    if not content.strip():
                        LOG.warning(f"assistant turn has empty text: {prompt}")
                    # Skip the extra EOS for chatml when its separator already
                    # is the tokenizer's EOS token.
                    sep_is_eos = (
                        conversation.name == "chatml"
                        and conversation.sep == self.tokenizer.eos_token
                    )
                    res = self._tokenize(
                        turn,
                        add_eos_token=not sep_is_eos,
                        strip_bos_token=True,
                    )
                    role_res = self._tokenize(
                        role.rstrip(),
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    labels = copy.deepcopy(res["input_ids"])
                    if not self.train_on_inputs:
                        # Mask only the role prefix of the assistant turn.
                        len_role = len(role_res["input_ids"])
                        labels[:len_role] = [IGNORE_TOKEN_ID] * min(
                            len_role, len(labels)
                        )
                elif role == "":
                    # Turn without a role prefix: tokenized with BOS kept.
                    turn = content
                    res = self._tokenize(
                        turn, add_eos_token=False, strip_bos_token=False
                    )
                    labels = self._context_labels(res["input_ids"])
                else:
                    LOG.warning(f"unhandled role: {role}")
                    continue

                result, current_len = parse_tokenized_to_result(
                    result,
                    current_len,
                    res,
                    labels,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            return result
        except (KeyError, AssertionError, IndexError) as err:
            raise InvalidDataException(str(err)) from err
|
|
|
|
|
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
    """
    Build the empty starting state for accumulating a tokenized prompt.

    Returns a ``(result, current_len)`` tuple: ``result`` maps ``input_ids``,
    ``attention_mask`` and ``labels`` to fresh empty lists, and
    ``current_len`` is 0.
    """
    field_names = ("input_ids", "attention_mask", "labels")
    # One independent list per field.
    accumulator: Dict[str, List[int]] = {name: [] for name in field_names}
    return accumulator, 0
|
|
|
|
|
def parse_tokenized_to_result(
    result: Dict[str, List[int]],
    current_len: int,
    res: Dict[str, List[int]],
    labels: List[int],
    pad_token_id: Union[int, None] = None,
) -> Tuple[Dict[str, List[int]], int]:
    """
    Write one tokenized segment into the accumulated result.

    The segment's ``input_ids``, a freshly computed attention mask and the
    given ``labels`` are slice-assigned into ``result`` starting at
    ``current_len``. The mask is 0 where a token equals ``pad_token_id`` and
    1 elsewhere. ``result`` is modified in place.

    Returns the same ``result`` object and the length after the segment.
    """
    segment_ids = res["input_ids"]
    end = current_len + len(segment_ids)
    window = slice(current_len, end)

    result["input_ids"][window] = segment_ids
    result["attention_mask"][window] = [
        int(token != pad_token_id) for token in segment_ids
    ]
    result["labels"][window] = labels

    return result, end
|
|