from pathlib import Path

from tokenizers import Tokenizer
from tokenizers.models import BPE, WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer, WordLevelTrainer


def get_all_sentences(ds, lang: str):
    """Yield every sentence in the given language from a translation dataset."""
    for item in ds:
        yield item['translation'][lang]


def get_or_build_local_tokenizer(config, ds, lang: str, tokenizer_type: str, force_build: bool = False) -> Tokenizer:
    """Load the tokenizer for `lang` from disk, or train and save a new one.

    A new tokenizer is trained when the file is missing or `force_build` is set.
    """
    tokenizer_path = Path(config['dataset']['tokenizer_file'].format(lang))
    if not tokenizer_path.exists() or force_build:
        if ds is None:
            raise ValueError("Cannot build tokenizer: no local tokenizer file found and dataset is None")

        if tokenizer_type == "WordLevel":
            # Word-level model: one token per whitespace-separated word.
            tokenizer = Tokenizer(WordLevel(unk_token='<unk>'))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = WordLevelTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
        elif tokenizer_type == "BPE":
            # BPE model: subword vocabulary learned from merge statistics.
            tokenizer = Tokenizer(BPE(unk_token='<unk>'))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = BpeTrainer(special_tokens=['<unk>', '<pad>', '<sos>', '<eos>'], min_frequency=2)
        else:
            raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")

        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        # Ensure the target directory exists before saving.
        tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
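

# A minimal usage sketch, not part of the original module. The config layout and
# the tiny in-memory dataset below are assumptions for illustration: any iterable
# of {'translation': {lang: text}} records works (e.g. a Hugging Face
# `load_dataset` translation split), and config['dataset']['tokenizer_file'] is
# expected to contain a '{}' placeholder for the language code.
if __name__ == "__main__":
    config = {'dataset': {'tokenizer_file': 'tokenizers/tokenizer_{0}.json'}}
    ds = [
        {'translation': {'en': 'hello world', 'it': 'ciao mondo'}},
        {'translation': {'en': 'hello world again', 'it': 'ciao di nuovo mondo'}},
    ]
    tokenizer = get_or_build_local_tokenizer(config, ds, lang='en', tokenizer_type='WordLevel')
    # 'hello' and 'world' appear twice, so they survive min_frequency=2;
    # 'again' appears once and falls below the threshold, mapping to '<unk>'.
    print(tokenizer.encode('hello world again').ids)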