"""pretraining prompt strategies"""
from typing import Generator
from transformers import BatchEncoding
from axolotl.prompt_tokenizers import PromptTokenizingStrategy


class PretrainTokenizer:
    """basic tokenization class for pretraining"""

    def build_prompt(self, prompt) -> Generator[str, None, None]:
        # Pretraining text is used verbatim; yield it without applying any template.
        yield prompt


class PretrainTokenizationStrategy(PromptTokenizingStrategy):
    """handles tokenization for pretraining with strides"""

    @property
    def supports_batched(self):
        # Signals that tokenize_prompt can be fed batches of rows at once.
        return True

    def __init__(self, *args, max_length=None, text_column="text", **kwargs):
        super().__init__(*args, **kwargs)
        # Override the default max_length from the base class when one is supplied.
        if max_length:
            self.max_length = max_length
        self.text_column = text_column

    def _tokenize(
        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
    ) -> BatchEncoding:
        # Reserve one position for the EOS token appended below, and let the
        # tokenizer split long documents into overlapping windows: with
        # return_overflowing_tokens=True, consecutive windows share `stride` tokens.
        res = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length - 1,
            add_special_tokens=True,
            return_overflowing_tokens=True,
            stride=256,
        )
        # Terminate every window with EOS and extend the attention mask to match.
        res["input_ids"] = [
            seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
        ]
        res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
        return res

    def tokenize_prompt(self, prompt):
        return self._tokenize(prompt[self.text_column])


def load(tokenizer, cfg):
    strat = PretrainTokenizationStrategy(
        PretrainTokenizer(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
        max_length=cfg.sequence_len * 64,
    )
    return strat
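

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal way to exercise the strategy directly, assuming a Hugging Face
# tokenizer and a config object exposing the fields that `load` reads above.
# The "gpt2" tokenizer, the SimpleNamespace config, and the sample text are
# placeholders, not axolotl's actual entry point.
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    cfg = SimpleNamespace(
        train_on_inputs=True,
        sequence_len=32,  # keeps max_length small so the windowing is visible
        pretraining_dataset=[{"text_column": "text"}],
    )
    strategy = load(tok, cfg)

    # A long document becomes one or more overlapping windows, each capped at
    # max_length (sequence_len * 64) and terminated with the EOS token.
    encoded = strategy.tokenize_prompt({"text": "some long pretraining text " * 500})
    print(len(encoded["input_ids"]), [len(ids) for ids in encoded["input_ids"]])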