# ru-word-stress-transformer / char_tokenizer.py

import os
from typing import Optional, Tuple, List
from collections import OrderedDict

from transformers import PreTrainedTokenizer


def load_vocab(vocab_file):
    # Read one token per line; the token id is the line index.
    vocab = OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


class CharTokenizer(PreTrainedTokenizer):
    # Character-level tokenizer: every character is its own token.
    vocab_files_names = {"vocab_file": "vocab.txt"}

    def __init__(
        self,
        vocab_file=None,
        pad_token="[PAD]",
        unk_token="[UNK]",
        bos_token="[BOS]",
        eos_token="[EOS]",
        *args,
        **kwargs
    ):
        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            **kwargs
        )
        # Start with an empty vocabulary if no vocab file is given;
        # it can be filled later with train().
        if not vocab_file or not os.path.isfile(vocab_file):
            self.vocab = OrderedDict()
            self.ids_to_tokens = OrderedDict()
        else:
            self.vocab = load_vocab(vocab_file)
            self.ids_to_tokens = OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

    def train(self, file_path):
        # Build the character vocabulary from a plain-text file with one word per line.
        vocab = set()
        with open(file_path, encoding="utf-8") as r:
            for line in r:
                word = line.strip()
                vocab |= set(word)
        vocab = sorted(vocab)
        special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
        vocab = special_tokens + vocab
        for i, ch in enumerate(vocab):
            self.vocab[ch] = i
        self.ids_to_tokens = vocab

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return self.vocab

    def _convert_token_to_id(self, token):
        # Map unknown characters to the [UNK] id instead of returning None.
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens[index]

    def _tokenize(self, text):
        return list(text)

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Wrap a single sequence as [BOS] ... [EOS]; sequence pairs are not supported.
        bos = [self.bos_token_id]
        eos = [self.eos_token_id]
        return bos + token_ids_0 + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        return (len(token_ids_0) + 2) * [0]

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        # Write tokens in id order, one per line, so load_vocab() can read them back.
        assert os.path.isdir(save_directory)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            self.vocab_files_names["vocab_file"]
        )
        index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                assert index == token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
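

# A minimal usage sketch, not part of the original file: train the tokenizer on a
# word list, save the vocabulary, and encode one word. The paths "words.txt" and
# "vocab_dir" are hypothetical placeholders, and the sketch assumes a transformers
# version compatible with this class (special tokens need not be in the vocabulary
# at construction time, as in the original __init__).
if __name__ == "__main__":
    tokenizer = CharTokenizer()
    tokenizer.train("words.txt")            # plain-text file, one word per line
    os.makedirs("vocab_dir", exist_ok=True)
    tokenizer.save_vocabulary("vocab_dir")  # writes vocab_dir/vocab.txt
    encoding = tokenizer("молоко")          # character ids wrapped in [BOS] ... [EOS]
    print(encoding["input_ids"])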