# ReactSeq/onmt/inputters/inputter.py
# -*- coding: utf-8 -*-
import os
import math
import codecs
import torch
import pyonmttok
from onmt.constants import DefaultTokens
from onmt.transforms import TransformPipe
class IterOnDevice(torch.utils.data.IterableDataset):
"""Sent items from `iterable` on `device_id` and yield."""
def __init__(self, iterable, device_id):
        super().__init__()
self.iterable = iterable
self.device_id = device_id
# temporary as long as translation_server and scoring_preparator still use lists
if hasattr(iterable, "transforms"):
self.transform = TransformPipe.build_from(
[iterable.transforms[name] for name in iterable.transforms]
)
@staticmethod
def batch_to_device(tensor_batch, device_id):
"""Move `batch` to `device_id`, cpu if `device_id` < 0."""
device = torch.device(device_id) if device_id >= 0 else torch.device("cpu")
for key in tensor_batch.keys():
if key != "src_ex_vocab":
tensor_batch[key] = tensor_batch[key].to(device)
def __iter__(self):
for tensor_batch in self.iterable:
self.batch_to_device(tensor_batch, self.device_id)
yield tensor_batch
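
# A minimal usage sketch, not part of the original module: the `_demo_*` name
# and the fake batches are assumptions for illustration. In real training the
# iterable is a corpus iterator exposing a `transforms` dict; a plain list has
# no such attribute, so the TransformPipe branch in __init__ is simply skipped.
def _demo_iter_on_device():
    fake_batches = [
        {"src": torch.randint(0, 10, (4, 5)), "tgt": torch.randint(0, 10, (4, 6))}
        for _ in range(2)
    ]
    # device_id=-1 keeps tensors on CPU; a GPU index (e.g. 0) would move them there.
    for batch in IterOnDevice(fake_batches, device_id=-1):
        assert batch["src"].device.type == "cpu"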
def build_vocab(opt, specials):
"""Build vocabs dict to be stored in the checkpoint
based on vocab files having each line [token, count]
Args:
opt: src_vocab, tgt_vocab, n_src_feats
Return:
vocabs: {'src': pyonmttok.Vocab, 'tgt': pyonmttok.Vocab,
'src_feats' : [pyonmttok.Vocab, ...]},
'data_task': seq2seq or lm
'decoder_start_token': DefaultTokens.BOS
}
"""
def _pad_vocab_to_multiple(vocab, multiple):
vocab_size = len(vocab)
if vocab_size % multiple == 0:
return vocab
target_size = int(math.ceil(vocab_size / multiple)) * multiple
for i in range(target_size - vocab_size):
vocab.add_token(DefaultTokens.VOCAB_PAD + str(i))
return vocab
default_specials = opt.default_specials
vocabs = {}
src_vocab = _read_vocab_file(opt.src_vocab, opt.src_words_min_frequency)
src_specials = [
item for item in (default_specials + specials["src"]) if item not in src_vocab
]
if DefaultTokens.SEP in src_specials and (
"<0x0A>" in src_vocab or "Ċ" in src_vocab
):
        # This is a hack: the "docify" transform's get_specials returns the
        # special separator ⦅newline⦆, but we skip adding it when the matching
        # newline token is already in the vocab, i.e. "<0x0A>" (SentencePiece)
        # or "Ċ" (BPE with GPT-2 pretokenization).
src_specials.remove(DefaultTokens.SEP)
src_vocab = pyonmttok.build_vocab_from_tokens(
src_vocab, maximum_size=opt.src_vocab_size, special_tokens=src_specials
)
src_vocab.default_id = src_vocab[DefaultTokens.UNK]
if opt.vocab_size_multiple > 1:
src_vocab = _pad_vocab_to_multiple(src_vocab, opt.vocab_size_multiple)
vocabs["src"] = src_vocab
if opt.share_vocab:
vocabs["tgt"] = src_vocab
else:
tgt_vocab = _read_vocab_file(opt.tgt_vocab, opt.tgt_words_min_frequency)
tgt_specials = [
item
for item in (default_specials + specials["tgt"])
if item not in tgt_vocab
]
if DefaultTokens.SEP in tgt_specials and (
"<0x0A>" in tgt_vocab or "Ċ" in src_vocab
):
tgt_specials.remove(DefaultTokens.SEP)
tgt_vocab = pyonmttok.build_vocab_from_tokens(
tgt_vocab, maximum_size=opt.tgt_vocab_size, special_tokens=tgt_specials
)
tgt_vocab.default_id = tgt_vocab[DefaultTokens.UNK]
if opt.vocab_size_multiple > 1:
tgt_vocab = _pad_vocab_to_multiple(tgt_vocab, opt.vocab_size_multiple)
vocabs["tgt"] = tgt_vocab
if opt.n_src_feats > 0:
src_feats_vocabs = []
for i in range(opt.n_src_feats):
src_f_vocab = _read_vocab_file(f"{opt.src_vocab}_feat{i}", 1)
src_f_vocab = pyonmttok.build_vocab_from_tokens(
src_f_vocab,
maximum_size=0,
minimum_frequency=1,
special_tokens=default_specials,
)
src_f_vocab.default_id = src_f_vocab[DefaultTokens.UNK]
if opt.vocab_size_multiple > 1:
src_f_vocab = _pad_vocab_to_multiple(
src_f_vocab, opt.vocab_size_multiple
)
src_feats_vocabs.append(src_f_vocab)
vocabs["src_feats"] = src_feats_vocabs
vocabs["data_task"] = opt.data_task
vocabs["decoder_start_token"] = opt.decoder_start_token
return vocabs
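
# A hedged sketch, not part of the original module: build_vocab reads an
# argparse-style `opt`, and a SimpleNamespace carrying just the fields used
# above is enough for illustration. The field values and the vocab file name
# are assumptions. The vocab file holds one "token count" pair per line.
def _demo_build_vocab(tmp_path="demo_vocab.txt"):
    from types import SimpleNamespace

    with codecs.open(tmp_path, "w", "utf-8") as f:
        f.write("hello 5\nworld 3\nrare 1\n")
    opt = SimpleNamespace(
        default_specials=[
            DefaultTokens.UNK,
            DefaultTokens.PAD,
            DefaultTokens.BOS,
            DefaultTokens.EOS,
        ],
        src_vocab=tmp_path,
        src_words_min_frequency=2,  # drops "rare" (count 1)
        src_vocab_size=100,
        vocab_size_multiple=8,  # pad with dummy tokens up to a multiple of 8
        share_vocab=True,  # reuse the src vocab as tgt
        n_src_feats=0,
        data_task="seq2seq",
        decoder_start_token=DefaultTokens.BOS,
    )
    vocabs = build_vocab(opt, specials={"src": [], "tgt": []})
    assert len(vocabs["src"]) % 8 == 0  # padded by _pad_vocab_to_multiple
    assert vocabs["tgt"] is vocabs["src"]  # shared
    return vocabs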
def _read_vocab_file(vocab_path, min_count):
"""Loads a vocabulary from the given path.
Args:
vocab_path (str): Path to utf-8 text file containing vocabulary.
Each token should be on a line, may followed with a count number
seperate by space if `with_count`. No extra whitespace is allowed.
min_count (int): retains only tokens with min_count frequency.
"""
if not os.path.exists(vocab_path):
raise RuntimeError("Vocabulary not found at {}".format(vocab_path))
else:
with codecs.open(vocab_path, "rb", "utf-8") as f:
lines = [line.strip("\n") for line in f if line.strip("\n")]
first_line = lines[0].split(None, 1)
has_count = len(first_line) == 2 and first_line[-1].isdigit()
if has_count:
                vocab = []
                for line in lines:
                    token, count = line.split(None, 1)
                    if int(count) >= min_count:
                        vocab.append(token)
else:
vocab = lines
return vocab
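
# A hedged sketch, not part of the original module: demonstrates the counted
# file format _read_vocab_file accepts. The file name and tokens are assumptions.
def _demo_read_vocab_file(tmp_path="demo_counts.txt"):
    with codecs.open(tmp_path, "w", "utf-8") as f:
        f.write("the 1000\ncat 42\nzyzzyva 1\n")
    # Counts are auto-detected from the first line; with min_count=10 only
    # "the" and "cat" survive. A count-less file would be returned as-is.
    return _read_vocab_file(tmp_path, min_count=10)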
def vocabs_to_dict(vocabs):
"""
Convert a dict of pyonmttok vocabs
into a plain text dict to be saved in the checkpoint
"""
vocabs_dict = {}
vocabs_dict["src"] = vocabs["src"].ids_to_tokens
vocabs_dict["tgt"] = vocabs["tgt"].ids_to_tokens
if "src_feats" in vocabs.keys():
vocabs_dict["src_feats"] = [
feat_vocab.ids_to_tokens for feat_vocab in vocabs["src_feats"]
]
vocabs_dict["data_task"] = vocabs["data_task"]
if "decoder_start_token" in vocabs.keys():
vocabs_dict["decoder_start_token"] = vocabs["decoder_start_token"]
else:
vocabs_dict["decoder_start_token"] = DefaultTokens.BOS
return vocabs_dict
def dict_to_vocabs(vocabs_dict):
"""
Convert a dict formatted vocabs (as stored in a checkpoint)
into a dict of pyonmttok vocabs objects.
"""
vocabs = {}
vocabs["data_task"] = vocabs_dict["data_task"]
if "decoder_start_token" in vocabs_dict.keys():
vocabs["decoder_start_token"] = vocabs_dict["decoder_start_token"]
else:
vocabs["decoder_start_token"] = DefaultTokens.BOS
vocabs["src"] = pyonmttok.build_vocab_from_tokens(vocabs_dict["src"])
if vocabs_dict["src"] == vocabs_dict["tgt"]:
vocabs["tgt"] = vocabs["src"]
else:
vocabs["tgt"] = pyonmttok.build_vocab_from_tokens(vocabs_dict["tgt"])
if "src_feats" in vocabs_dict.keys():
vocabs["src_feats"] = []
for feat_vocab in vocabs_dict["src_feats"]:
vocabs["src_feats"].append(pyonmttok.build_vocab_from_tokens(feat_vocab))
return vocabs
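
# A hedged round-trip sketch, not part of the original module: checkpoints
# store only plain token lists; dict_to_vocabs rebuilds pyonmttok.Vocab
# objects and restores src/tgt sharing when the two token lists are equal.
def _demo_vocab_round_trip():
    tokens = [DefaultTokens.UNK, DefaultTokens.PAD, "hello", "world"]
    shared = pyonmttok.build_vocab_from_tokens(tokens)
    vocabs = {"src": shared, "tgt": shared, "data_task": "seq2seq"}
    restored = dict_to_vocabs(vocabs_to_dict(vocabs))
    assert restored["src"] is restored["tgt"]  # sharing preserved
    assert restored["decoder_start_token"] == DefaultTokens.BOS  # default filled in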