"""pretraining prompt strategies""" from typing import Generator from transformers import BatchEncoding from axolotl.prompt_tokenizers import PromptTokenizingStrategy class PretrainTokenizer: """basic tokenization class for pretraining""" def build_prompt(self, prompt) -> Generator[str, None, None]: yield prompt class PretrainTokenizationStrategy(PromptTokenizingStrategy): """handles tokenization for pretraining with strides""" @property def supports_batched(self): return True def __init__(self, *args, max_length=None, text_column="text", **kwargs): super().__init__(*args, **kwargs) if max_length: self.max_length = max_length self.text_column = text_column def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False ) -> BatchEncoding: res = self.tokenizer( prompt, truncation=True, max_length=self.max_length - 1, add_special_tokens=True, return_overflowing_tokens=True, stride=256, ) res["input_ids"] = [ seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"] ] res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]] return res def tokenize_prompt(self, prompt): return self._tokenize(prompt[self.text_column]) def load(tokenizer, cfg): strat = PretrainTokenizationStrategy( PretrainTokenizer(), tokenizer, cfg.train_on_inputs, cfg.sequence_len, text_column=cfg.pretraining_dataset[0]["text_column"] or "text", max_length=cfg.sequence_len * 64, ) return strat