from typing import List, Optional

from transformers import GPT2Tokenizer

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "yuchenxie/arlow-gpt": "https://huggingface.co/yuchenxie/GPT-2V/resolve/main/vocab.json",
    },
    "merges_file": {
        "yuchenxie/arlow-gpt": "https://huggingface.co/yuchenxie/GPT-2V/resolve/main/merges.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "yuchenxie/arlow-gpt": 1024,
}


class ArlowGPTTokenizer(GPT2Tokenizer):
    """ArlowGPT tokenizer, derived from GPT2Tokenizer with custom configurations."""

    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        pad_token=None,
        add_prefix_space=False,
        add_special_tokens=True,
        padding_side="right",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_special_tokens=add_special_tokens,
            padding_side=padding_side,
            **kwargs,
        )

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a mask with 1 at special-token positions and 0 elsewhere."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        # Single sequence: <bos> ids_0 <eos>
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        # Sequence pair: <bos> ids_0 <eos> ids_1 <eos>
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Wrap token ids with BOS/EOS (both "<|endoftext|>" by default)."""
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return (
            [self.bos_token_id]
            + token_ids_0
            + [self.eos_token_id]
            + token_ids_1
            + [self.eos_token_id]
        )
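

# Usage sketch (illustrative only): shows how the tokenizer above might be
# loaded and how its custom BOS/EOS wrapping and special-tokens mask line up.
# The hub id "yuchenxie/arlow-gpt" is taken from the maps above; whether the
# vocab/merges files actually resolve from that repo is an assumption of this
# example, not something this module guarantees.
if __name__ == "__main__":
    tokenizer = ArlowGPTTokenizer.from_pretrained("yuchenxie/arlow-gpt")

    # Encode without special tokens, then add them explicitly to observe the
    # <|endoftext|> ... <|endoftext|> wrapping defined by
    # build_inputs_with_special_tokens.
    ids = tokenizer.encode("Hello, world!", add_special_tokens=False)
    wrapped = tokenizer.build_inputs_with_special_tokens(ids)
    mask = tokenizer.get_special_tokens_mask(ids)

    print(wrapped)  # [bos_token_id, ..., eos_token_id]
    print(mask)     # [1, 0, ..., 0, 1]
    # The mask covers the wrapped sequence: one 1 per added special token.
    assert len(mask) == len(wrapped)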