# GPT-2V / tokenization_arlow_gpt.py
# Created by yuchenxie (commit 70e4476, verified) — this header is repository
# metadata captured from the model-hub page, kept here as a comment.
from transformers import GPT2Tokenizer
from typing import Optional, List, Union, Dict
# Base URL of the hosted tokenizer assets; both maps point into this repo.
_ARLOW_GPT_REPO_URL = "https://huggingface.co/yuchenxie/GPT-2V/resolve/main"

# Maps each vocabulary-file kind to {checkpoint name: download URL}.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {"yuchenxie/arlow-gpt": _ARLOW_GPT_REPO_URL + "/vocab.json"},
    "merges_file": {"yuchenxie/arlow-gpt": _ARLOW_GPT_REPO_URL + "/merges.txt"},
}

# Maximum sequence length (positional-embedding size) per checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "yuchenxie/arlow-gpt": 1024,
}
class ArlowGPTTokenizer(GPT2Tokenizer):
    """GPT-2-based tokenizer for ArlowGPT.

    Identical to ``GPT2Tokenizer`` except that, unlike stock GPT-2, it wraps
    sequences in explicit special tokens: ``<bos> ids_0 <eos>`` for a single
    sequence and ``<bos> ids_0 <eos> ids_1 <eos>`` for a pair.
    """

    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        pad_token=None,
        add_prefix_space=False,
        add_special_tokens=True,
        padding_side="right",
        **kwargs,
    ):
        """Delegate everything to ``GPT2Tokenizer``; no extra state is kept here."""
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_special_tokens=add_special_tokens,
            padding_side=padding_side,
            **kwargs,
        )

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """Return a 0/1 mask (1 = special token) matching the layout produced
        by :meth:`build_inputs_with_special_tokens`.
        """
        if already_has_special_tokens:
            # Token lists already carry specials — let the base class detect them.
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        # Layout: <bos> ids_0 <eos> [ids_1 <eos>]
        mask = [1, *([0] * len(token_ids_0)), 1]
        if token_ids_1 is not None:
            mask.extend([0] * len(token_ids_1))
            mask.append(1)
        return mask

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """Wrap token ids with BOS/EOS: ``<bos> ids_0 <eos>`` for one sequence,
        ``<bos> ids_0 <eos> ids_1 <eos>`` for a pair.
        """
        wrapped = [self.bos_token_id, *token_ids_0, self.eos_token_id]
        if token_ids_1 is not None:
            wrapped.extend(token_ids_1)
            wrapped.append(self.eos_token_id)
        return wrapped