|
"""Module containing the MetharmePromptTokenizingStrategy and MetharmePrompter classes"""
|
|
|
import logging |
|
from typing import Tuple |
|
|
|
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy |
|
from axolotl.prompters import AlpacaPrompter |
|
|
|
LOG = logging.getLogger("axolotl") |
|
|
|
IGNORE_TOKEN_ID = -100 |
|
|
|
|
|
|
|
|
|
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for the Metharme models
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Map a Metharme dataset row onto (instruction, input, response).

        Metharme rows carry only ``prompt`` and ``generation`` keys; there is
        no separate "input" field, so the middle element is always empty.
        """
        return (prompt["prompt"], "", prompt["generation"])

    def _tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        strip_bos_token: bool = False,
        num_eos_tokens: int = 3,
    ):
        """Tokenize ``prompt``, optionally appending EOS / stripping BOS.

        Args:
            prompt: Raw text to tokenize.
            add_eos_token: When True, pad the sequence with up to
                ``num_eos_tokens`` EOS tokens (without exceeding
                ``self.sequence_len``).
            strip_bos_token: When True, drop a leading BOS token if present.
            num_eos_tokens: Maximum number of EOS tokens to append; reduced
                by one if the tokenizer already emitted a trailing EOS.

        Returns:
            The tokenizer output dict with ``input_ids``, ``attention_mask``
            and a ``labels`` key copied from ``input_ids``.
        """
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")

        # Guard the index access: on an empty token list the previous
        # unconditional `result["input_ids"][-1]` raised IndexError right
        # after the warning above.
        if result["input_ids"] and result["input_ids"][-1] == self.tokenizer.eos_token_id:
            num_eos_tokens -= 1

        if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
            for _ in range(num_eos_tokens):
                # Never grow the sequence past the configured maximum length.
                if len(result["input_ids"]) < self.sequence_len:
                    result["input_ids"].append(self.tokenizer.eos_token_id)
                    result["attention_mask"].append(1)

        # Same empty-list guard for the leading-BOS check.
        if (
            result["input_ids"]
            and result["input_ids"][0] == self.tokenizer.bos_token_id
            and strip_bos_token
        ):
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        # Labels mirror the inputs; masking of prompt tokens (if any) is
        # handled by the base strategy.
        result["labels"] = result["input_ids"].copy()
        return result
|
|
|
|
|
class MetharmePrompter(AlpacaPrompter):
    """
    Prompter for the Metharme models.

    Metharme training data already embeds its own role/special tokens in the
    prompt text, so every Alpaca template is blanked out: the turn formats
    pass the instruction through verbatim with no system preamble.
    """

    # No system prompt or wrapping — each turn is the raw instruction string.
    system_prompt = ""
    system_no_input_prompt = ""
    system_format = ""
    turn_format = "{instruction}"
    turn_no_input_format = "{instruction}"

    def __init__(self, *args, **kwargs):
        # Intentionally skip AlpacaPrompter.__init__: there is no template
        # to assemble, and any positional/keyword args are ignored.
        pass
|
|
|
|
|
def load(tokenizer, cfg):
    """Construct the Metharme tokenizing strategy from an axolotl config.

    Args:
        tokenizer: The HuggingFace tokenizer to use.
        cfg: Axolotl config providing ``train_on_inputs`` and ``sequence_len``.

    Returns:
        A configured ``MetharmePromptTokenizingStrategy``.
    """
    prompter = MetharmePrompter()
    strategy = MetharmePromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy
|
|