|
|
|
""" Translator Class and builder """ |
|
import codecs |
|
import os |
|
import time |
|
import numpy as np |
|
from itertools import count, zip_longest |
|
from copy import deepcopy |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.nn.utils.rnn import pad_sequence |
|
from onmt.constants import DefaultTokens |
|
import onmt.model_builder |
|
import onmt.decoders.ensemble |
|
from onmt.translate.beam_search import BeamSearch, BeamSearchLM |
|
from onmt.translate.greedy_search import GreedySearch, GreedySearchLM |
|
from onmt.utils.misc import tile, set_random_seed, report_matrix |
|
from onmt.utils.alignment import extract_alignment, build_align_pharaoh |
|
from onmt.modules.copy_generator import collapse_copy_scores |
|
from onmt.constants import ModelTask |
|
|
|
|
|
def build_translator(opt, device_id=0, report_score=True, logger=None, out_file=None): |
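    """Build a :class:`Translator` or :class:`GeneratorLM` from options.

    A minimal usage sketch (assuming ``opt`` was parsed with the regular
    translate options, so that fields such as ``opt.models`` and
    ``opt.output`` exist)::

        translator = build_translator(opt, logger=logger)
        scores, preds = translator._translate(infer_iter)
    """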
|
if out_file is None: |
|
out_file = codecs.open(opt.output, "w+", "utf-8") |
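    # An ensemble (several checkpoints in opt.models) needs the dedicated
    # ensemble loader; a single checkpoint uses the regular one.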
|
|
|
load_test_model = ( |
|
onmt.decoders.ensemble.load_test_model |
|
if len(opt.models) > 1 |
|
else onmt.model_builder.load_test_model |
|
) |
|
|
|
vocabs, model, model_opt = load_test_model(opt, device_id) |
|
|
|
scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt) |
|
|
|
if model_opt.model_task == ModelTask.LANGUAGE_MODEL: |
|
translator = GeneratorLM.from_opt( |
|
model, |
|
vocabs, |
|
opt, |
|
model_opt, |
|
global_scorer=scorer, |
|
out_file=out_file, |
|
report_align=opt.report_align, |
|
report_score=report_score, |
|
logger=logger, |
|
) |
|
else: |
|
translator = Translator.from_opt( |
|
model, |
|
vocabs, |
|
opt, |
|
model_opt, |
|
global_scorer=scorer, |
|
out_file=out_file, |
|
report_align=opt.report_align, |
|
report_score=report_score, |
|
logger=logger, |
|
) |
|
return translator |
|
|
|
|
|
class Inference(object): |
|
"""Translate a batch of sentences with a saved model. |
|
|
|
Args: |
|
model (onmt.modules.NMTModel): NMT model to use for translation |
|
        vocabs (dict[str, Vocab]): A dict mapping each side ('src' and
            'tgt') to its Vocab.
|
gpu (int): GPU device. Set to negative for no GPU. |
|
        n_best (int): Number of best hypotheses to output for each input.
|
min_length (int): See |
|
:class:`onmt.translate.decode_strategy.DecodeStrategy`. |
|
max_length (int): See |
|
:class:`onmt.translate.decode_strategy.DecodeStrategy`. |
|
beam_size (int): Number of beams. |
|
random_sampling_topk (int): See |
|
:class:`onmt.translate.greedy_search.GreedySearch`. |
|
random_sampling_temp (float): See |
|
:class:`onmt.translate.greedy_search.GreedySearch`. |
|
stepwise_penalty (bool): Whether coverage penalty is applied every step |
|
or not. |
|
        dump_beam (str): If set, path of a file to dump beam contents to,
            for debugging.
|
block_ngram_repeat (int): See |
|
:class:`onmt.translate.decode_strategy.DecodeStrategy`. |
|
ignore_when_blocking (set or frozenset): See |
|
:class:`onmt.translate.decode_strategy.DecodeStrategy`. |
|
replace_unk (bool): Replace unknown token. |
|
        tgt_file_prefix (bool): Force the predictions to begin with the provided -tgt.
|
data_type (str): Source data type. |
|
verbose (bool): Print/log every translation. |
|
report_time (bool): Print/log total time/frequency. |
|
copy_attn (bool): Use copy attention. |
|
global_scorer (onmt.translate.GNMTGlobalScorer): Translation |
|
scoring/reranking object. |
|
out_file (TextIO or codecs.StreamReaderWriter): Output file. |
|
        report_score (bool): Whether to report scores.
|
logger (logging.Logger or NoneType): Logger. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
model, |
|
vocabs, |
|
gpu=-1, |
|
n_best=1, |
|
min_length=0, |
|
max_length=100, |
|
ratio=0.0, |
|
beam_size=30, |
|
random_sampling_topk=0, |
|
random_sampling_topp=0.0, |
|
random_sampling_temp=1.0, |
|
stepwise_penalty=None, |
|
        dump_beam="",
|
block_ngram_repeat=0, |
|
ignore_when_blocking=frozenset(), |
|
replace_unk=False, |
|
ban_unk_token=False, |
|
tgt_file_prefix=False, |
|
phrase_table="", |
|
data_type="text", |
|
verbose=False, |
|
report_time=False, |
|
copy_attn=False, |
|
global_scorer=None, |
|
out_file=None, |
|
report_align=False, |
|
gold_align=False, |
|
report_score=True, |
|
logger=None, |
|
seed=-1, |
|
with_score=False, |
|
): |
|
self.model = model |
|
self.vocabs = vocabs |
|
self._tgt_vocab = vocabs["tgt"] |
|
self._tgt_eos_idx = vocabs["tgt"].lookup_token(DefaultTokens.EOS) |
|
self._tgt_pad_idx = vocabs["tgt"].lookup_token(DefaultTokens.PAD) |
|
self._tgt_bos_idx = vocabs["tgt"].lookup_token(DefaultTokens.BOS) |
|
self._tgt_unk_idx = vocabs["tgt"].lookup_token(DefaultTokens.UNK) |
|
self._tgt_sep_idx = vocabs["tgt"].lookup_token(DefaultTokens.SEP) |
|
self._tgt_start_with = vocabs["tgt"].lookup_token(vocabs["decoder_start_token"]) |
|
self._tgt_vocab_len = len(self._tgt_vocab) |
|
|
|
self._gpu = gpu |
|
self._use_cuda = gpu > -1 |
|
self._dev = ( |
|
torch.device("cuda", self._gpu) if self._use_cuda else torch.device("cpu") |
|
) |
|
|
|
self.n_best = n_best |
|
self.max_length = max_length |
|
|
|
self.beam_size = beam_size |
|
self.random_sampling_temp = random_sampling_temp |
|
self.sample_from_topk = random_sampling_topk |
|
self.sample_from_topp = random_sampling_topp |
|
|
|
self.min_length = min_length |
|
self.ban_unk_token = ban_unk_token |
|
self.ratio = ratio |
|
self.stepwise_penalty = stepwise_penalty |
|
self.dump_beam = dump_beam |
|
self.block_ngram_repeat = block_ngram_repeat |
|
self.ignore_when_blocking = ignore_when_blocking |
|
self._exclusion_idxs = {self._tgt_vocab[t] for t in self.ignore_when_blocking} |
|
self.replace_unk = replace_unk |
|
if self.replace_unk and not self.model.decoder.attentional: |
|
raise ValueError("replace_unk requires an attentional decoder.") |
|
self.tgt_file_prefix = tgt_file_prefix |
|
self.phrase_table = phrase_table |
|
self.data_type = data_type |
|
self.verbose = verbose |
|
self.report_time = report_time |
|
|
|
self.copy_attn = copy_attn |
|
|
|
self.global_scorer = global_scorer |
|
if self.global_scorer.has_cov_pen and not self.model.decoder.attentional: |
|
raise ValueError("Coverage penalty requires an attentional decoder.") |
|
self.out_file = out_file |
|
self.report_align = report_align |
|
self.gold_align = gold_align |
|
self.report_score = report_score |
|
self.logger = logger |
|
|
|
self.use_filter_pred = False |
|
self._filter_pred = None |
|
|
|
|
|
self.beam_trace = self.dump_beam != "" |
|
self.beam_accum = None |
|
if self.beam_trace: |
|
self.beam_accum = { |
|
"predicted_ids": [], |
|
"beam_parent_ids": [], |
|
"scores": [], |
|
"log_probs": [], |
|
} |
|
|
|
set_random_seed(seed, self._use_cuda) |
|
self.with_score = with_score |
|
|
|
@classmethod |
|
def from_opt( |
|
cls, |
|
model, |
|
vocabs, |
|
opt, |
|
model_opt, |
|
global_scorer=None, |
|
out_file=None, |
|
report_align=False, |
|
report_score=True, |
|
logger=None, |
|
): |
|
"""Alternate constructor. |
|
|
|
Args: |
|
model (onmt.modules.NMTModel): See :func:`__init__()`. |
|
vocabs (dict[str, Vocab]): See |
|
:func:`__init__()`. |
|
opt (argparse.Namespace): Command line options |
|
model_opt (argparse.Namespace): Command line options saved with |
|
the model checkpoint. |
|
global_scorer (onmt.translate.GNMTGlobalScorer): See |
|
                :func:`__init__()`.
|
out_file (TextIO or codecs.StreamReaderWriter): See |
|
:func:`__init__()`. |
|
report_align (bool) : See :func:`__init__()`. |
|
report_score (bool) : See :func:`__init__()`. |
|
logger (logging.Logger or NoneType): See :func:`__init__()`. |
|
""" |
|
|
|
cls.validate_task(model_opt.model_task) |
|
|
|
return cls( |
|
model, |
|
vocabs, |
|
gpu=opt.gpu, |
|
n_best=opt.n_best, |
|
min_length=opt.min_length, |
|
max_length=opt.max_length, |
|
ratio=opt.ratio, |
|
beam_size=opt.beam_size, |
|
random_sampling_topk=opt.random_sampling_topk, |
|
random_sampling_topp=opt.random_sampling_topp, |
|
random_sampling_temp=opt.random_sampling_temp, |
|
stepwise_penalty=opt.stepwise_penalty, |
|
dump_beam=opt.dump_beam, |
|
block_ngram_repeat=opt.block_ngram_repeat, |
|
ignore_when_blocking=set(opt.ignore_when_blocking), |
|
replace_unk=opt.replace_unk, |
|
ban_unk_token=opt.ban_unk_token, |
|
tgt_file_prefix=opt.tgt_file_prefix, |
|
phrase_table=opt.phrase_table, |
|
data_type=opt.data_type, |
|
verbose=opt.verbose, |
|
report_time=opt.report_time, |
|
copy_attn=model_opt.copy_attn, |
|
global_scorer=global_scorer, |
|
out_file=out_file, |
|
report_align=report_align, |
|
gold_align=opt.gold_align, |
|
report_score=report_score, |
|
logger=logger, |
|
seed=opt.seed, |
|
with_score=opt.with_score, |
|
) |
|
|
|
def _log(self, msg): |
|
if self.logger: |
|
self.logger.info(msg) |
|
else: |
|
print(msg) |
|
|
|
def _gold_score( |
|
self, |
|
batch, |
|
enc_out, |
|
src_len, |
|
use_src_map, |
|
enc_final_hs, |
|
batch_size, |
|
src, |
|
): |
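        # When a gold target is available (and is not being used as a forced
        # prefix), score it under the model, then re-initialize the decoder
        # state that the scoring pass consumed.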
|
if "tgt" in batch.keys() and not self.tgt_file_prefix: |
|
gs = self._score_target( |
|
batch, |
|
enc_out, |
|
src_len, |
|
batch["src_map"] if use_src_map else None, |
|
) |
|
self.model.decoder.init_state(src, enc_out, enc_final_hs) |
|
else: |
|
gs = [0] * batch_size |
|
return gs |
|
|
|
def _translate( |
|
self, |
|
infer_iter, |
|
transform=None, |
|
attn_debug=False, |
|
align_debug=False, |
|
phrase_table="", |
|
): |
|
"""Translate content of ``src`` and get gold scores from ``tgt``. |
|
|
|
Args: |
|
            infer_iter: batch iterator yielding tensor batches, built by DynamicDatasetIter
|
attn_debug (bool): enables the attention logging |
|
align_debug (bool): enables the word alignment logging |
|
|
|
Returns: |
|
(`list`, `list`) |
|
|
|
* all_scores is a list of `batch_size` lists of `n_best` scores |
|
* all_predictions is a list of `batch_size` lists |
|
of `n_best` predictions |
|
""" |
|
xlation_builder = onmt.translate.TranslationBuilder( |
|
infer_iter, |
|
self.vocabs, |
|
self.n_best, |
|
self.replace_unk, |
|
self.phrase_table, |
|
) |
|
|
|
|
|
counter = count(1) |
|
pred_score_total, pred_words_total = 0, 0 |
|
gold_score_total, gold_words_total = 0, 0 |
|
|
|
all_scores = [] |
|
all_predictions = [] |
|
|
|
start_time = time.time() |
|
|
|
def _maybe_retranslate(translations, batch): |
|
"""Here we handle the cases of mismatch in number of segments |
|
between source and target. We re-translate seg by seg.""" |
|
inds, perm = torch.sort(batch["indices"]) |
|
trans_copy = deepcopy(translations) |
|
inserted_so_far = 0 |
|
for j, trans in enumerate(translations): |
|
if trans.src_raw.count(DefaultTokens.SEP) != trans.pred_sents[0].count( |
|
DefaultTokens.SEP |
|
): |
|
self._log("Mismatch in number of ((newline))") |
|
|
|
|
|
|
|
|
|
|
|
|
|
idx = (trans.src == self._tgt_sep_idx).nonzero() |
|
sub_src = [] |
|
start_idx = 0 |
|
for i in range(len(idx)): |
|
end_idx = idx[i] |
|
sub_src.append(batch["src"][perm[j], start_idx:end_idx, :]) |
|
start_idx = end_idx + 1 |
|
end_idx = ( |
|
batch["src"][perm[j], :, 0].ne(self._tgt_pad_idx).sum() - 1 |
|
) |
|
sub_src.append(batch["src"][perm[j], start_idx:end_idx, :]) |
|
t_sub_src = pad_sequence( |
|
sub_src, batch_first=True, padding_value=self._tgt_pad_idx |
|
) |
|
t_sub_src_len = t_sub_src[:, :, 0].ne(self._tgt_pad_idx).sum(1) |
|
                    t_sub_src_ind = torch.arange(len(sub_src), dtype=torch.int16)
|
device = batch["src"].device |
|
t_sub_batch = { |
|
"src": t_sub_src.to(device), |
|
"srclen": t_sub_src_len.to(device), |
|
"indices": t_sub_src_ind.to(device), |
|
} |
|
|
|
sub_data = self.translate_batch(t_sub_batch, attn_debug) |
|
sub_trans = xlation_builder.from_batch(sub_data) |
|
|
|
|
|
trans_copy[j + inserted_so_far] = sub_trans[0] |
|
for i in range(1, len(sub_src)): |
|
trans_copy.insert(j + i + inserted_so_far, sub_trans[i]) |
|
inserted_so_far += len(sub_src) - 1 |
|
return trans_copy |
|
|
|
for batch in infer_iter: |
|
batch_data = self.translate_batch(batch, attn_debug) |
|
|
|
translations = xlation_builder.from_batch(batch_data) |
|
if not isinstance(self, GeneratorLM): |
|
translations = _maybe_retranslate(translations, batch) |
|
|
|
for j, trans in enumerate(translations): |
|
all_scores += [trans.pred_scores[: self.n_best]] |
|
pred_score_total += trans.pred_scores[0] |
|
pred_words_total += len(trans.pred_sents[0]) |
|
if "tgt" in batch.keys(): |
|
gold_score_total += trans.gold_score |
|
gold_words_total += len(trans.gold_sent) + 1 |
|
|
|
n_best_preds = [ |
|
" ".join(pred) for pred in trans.pred_sents[: self.n_best] |
|
] |
|
|
|
n_best_scores = [ |
|
score.item() for score in trans.pred_scores[: self.n_best] |
|
] |
|
|
|
if self.report_align: |
|
align_pharaohs = [ |
|
build_align_pharaoh(align) |
|
for align in trans.word_aligns[: self.n_best] |
|
] |
|
n_best_preds_align = [ |
|
" ".join(align[0]) for align in align_pharaohs |
|
] |
|
n_best_preds = [ |
|
pred + DefaultTokens.ALIGNMENT_SEPARATOR + align |
|
for pred, align in zip(n_best_preds, n_best_preds_align) |
|
] |
|
|
|
if transform is not None: |
|
n_best_preds = transform.batch_apply_reverse(n_best_preds) |
|
|
|
all_predictions += [n_best_preds] |
|
|
|
out_all = [ |
|
pred + "\t" + str(score) |
|
for (pred, score) in zip(n_best_preds, n_best_scores) |
|
] |
|
|
|
if self.with_score: |
|
self.out_file.write("\n".join(out_all) + "\n") |
|
else: |
|
self.out_file.write("\n".join(n_best_preds) + "\n") |
|
self.out_file.flush() |
|
|
|
if self.verbose: |
|
sent_number = next(counter) |
|
output = trans.log(sent_number) |
|
if self.logger: |
|
self.logger.info(output) |
|
else: |
|
os.write(1, output.encode("utf-8")) |
|
|
|
if attn_debug: |
|
preds = trans.pred_sents[0] |
|
preds.append(DefaultTokens.EOS) |
|
attns = trans.attns[0].tolist() |
|
if self.data_type == "text": |
|
srcs = trans.src_raw |
|
else: |
|
srcs = [str(item) for item in range(len(attns[0]))] |
|
output = report_matrix(srcs, preds, attns) |
|
if self.logger: |
|
self.logger.info(output) |
|
else: |
|
os.write(1, output.encode("utf-8")) |
|
|
|
if align_debug: |
|
if self.gold_align: |
|
tgts = trans.gold_sent |
|
else: |
|
tgts = trans.pred_sents[0] |
|
align = trans.word_aligns[0].tolist() |
|
if self.data_type == "text": |
|
srcs = trans.src_raw |
|
else: |
|
srcs = [str(item) for item in range(len(align[0]))] |
|
output = report_matrix(srcs, tgts, align) |
|
if self.logger: |
|
self.logger.info(output) |
|
else: |
|
os.write(1, output.encode("utf-8")) |
|
|
|
end_time = time.time() |
|
|
|
if self.report_score: |
|
msg = self._report_score("PRED", pred_score_total, len(all_scores)) |
|
self._log(msg) |
|
if "tgt" in batch.keys() and not self.tgt_file_prefix: |
|
msg = self._report_score("GOLD", gold_score_total, len(all_scores)) |
|
self._log(msg) |
|
|
|
if self.report_time: |
|
total_time = end_time - start_time |
|
self._log("Total translation time (s): %.1f" % total_time) |
|
self._log( |
|
"Average translation time (ms): %.1f" |
|
% (total_time / len(all_predictions) * 1000) |
|
) |
|
self._log("Tokens per second: %.1f" % (pred_words_total / total_time)) |
|
|
|
if self.dump_beam: |
|
import json |
|
|
|
json.dump( |
|
self.translator.beam_accum, |
|
codecs.open(self.dump_beam, "w", "utf-8"), |
|
) |
|
|
|
return all_scores, all_predictions |
|
|
|
def _align_pad_prediction(self, predictions, bos, pad): |
|
""" |
|
        Pad predictions in a batch and prepend BOS.
|
|
|
Args: |
|
            predictions (List[List[Tensor]]): `(batch, n_best,)`, for each
                src sequence it contains n_best tgt predictions, each ending
                with the eos id.
|
bos (int): bos index to be used. |
|
pad (int): pad index to be used. |
|
|
|
        Returns:
|
batched_nbest_predict (torch.LongTensor): `(batch, n_best, tgt_l)` |
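
        Example (a sketch with arbitrary token indices)::

            >>> preds = [[torch.tensor([5, 6, 3])], [torch.tensor([7, 3])]]
            >>> self._align_pad_prediction(preds, bos=2, pad=1).shape
            torch.Size([2, 1, 4])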
|
""" |
|
dtype, device = predictions[0][0].dtype, predictions[0][0].device |
|
flatten_tgt = [best.tolist() for bests in predictions for best in bests] |
|
        padded_tgt = torch.tensor(
            list(zip_longest(*flatten_tgt, fillvalue=pad)),
            dtype=dtype,
            device=device,
        ).T
        bos_tensor = torch.full([padded_tgt.size(0), 1], bos, dtype=dtype, device=device)
        full_tgt = torch.cat((bos_tensor, padded_tgt), dim=-1)
|
batched_nbest_predict = full_tgt.view( |
|
len(predictions), -1, full_tgt.size(-1) |
|
) |
|
return batched_nbest_predict |
|
|
|
def _report_score(self, name, score_total, nb_sentences): |
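        # score_total is a sum of per-sentence log-probabilities (already
        # normalized by the length penalty when one is applied), so dividing
        # by the number of sentences gives an average per-sentence logprob,
        # and exp(-average) the corresponding perplexity.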
|
|
|
|
|
|
|
|
|
|
|
|
|
if nb_sentences == 0: |
|
msg = "%s No translations" % (name,) |
|
else: |
|
score = score_total / nb_sentences |
|
ppl = np.exp(-score_total.item() / nb_sentences) |
|
msg = "%s SCORE: %.4f, %s PPL: %.2f NB SENTENCES: %d" % ( |
|
name, |
|
score, |
|
name, |
|
ppl, |
|
nb_sentences, |
|
) |
|
return msg |
|
|
|
def _decode_and_generate( |
|
self, |
|
decoder_in, |
|
enc_out, |
|
batch, |
|
src_len, |
|
src_map=None, |
|
step=None, |
|
batch_offset=None, |
|
): |
|
if self.copy_attn: |
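            # Map any index from the extended (copy) vocabulary back to UNK
            # before the decoder embedding lookup.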
|
|
|
decoder_in = decoder_in.masked_fill( |
|
decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx |
|
) |
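
        # Decoder forward: decoder_in is (batch, tgt_len, nfeats) and enc_out
        # is (batch, src_len, hidden). At inference tgt_len == 1 and batch is
        # beam_size x batch_size; when gold-scoring, tgt_len is the full
        # target length.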
|
|
|
|
|
|
|
|
|
|
|
|
|
dec_out, dec_attn = self.model.decoder( |
|
decoder_in, |
|
enc_out, |
|
src_len=src_len, |
|
step=step, |
|
|
) |
|
|
|
|
|
if not self.copy_attn: |
|
if "std" in dec_attn: |
|
attn = dec_attn["std"] |
|
else: |
|
attn = None |
|
scores = self.model.generator(dec_out.squeeze(1)) |
|
log_probs = F.log_softmax(scores.to(torch.float32), dim=-1) |
|
|
|
|
|
else: |
|
attn = dec_attn["copy"] |
|
scores = self.model.generator( |
|
dec_out.view(-1, dec_out.size(2)), |
|
attn.view(-1, attn.size(2)), |
|
src_map, |
|
) |
|
|
|
if batch_offset is None: |
|
scores = scores.view(-1, len(batch["srclen"]), scores.size(-1)) |
|
scores = scores.transpose(0, 1).contiguous() |
|
else: |
|
scores = scores.view(-1, self.beam_size, scores.size(-1)) |
|
|
|
scores = collapse_copy_scores( |
|
scores, |
|
batch, |
|
self._tgt_vocab, |
|
batch_dim=0, |
|
batch_offset=batch_offset, |
|
) |
|
scores = scores.view(-1, decoder_in.size(1), scores.size(-1)) |
|
log_probs = scores.squeeze(1).log() |
|
|
|
|
|
return log_probs, attn |
|
|
|
def translate_batch(self, batch, attn_debug): |
|
"""Translate a batch of sentences.""" |
|
raise NotImplementedError |
|
|
|
def _score_target(self, batch, enc_out, src_len, src_map): |
|
raise NotImplementedError |
|
|
|
def report_results( |
|
self, |
|
gold_score, |
|
batch, |
|
batch_size, |
|
src, |
|
src_len, |
|
use_src_map, |
|
decode_strategy, |
|
): |
|
results = { |
|
"predictions": None, |
|
"scores": None, |
|
"attention": None, |
|
"batch": batch, |
|
"gold_score": gold_score, |
|
} |
|
|
|
results["scores"] = decode_strategy.scores |
|
results["predictions"] = decode_strategy.predictions |
|
results["attention"] = decode_strategy.attention |
|
if self.report_align: |
|
results["alignment"] = self._align_forward( |
|
batch, decode_strategy.predictions |
|
) |
|
else: |
|
results["alignment"] = [[] for _ in range(batch_size)] |
|
return results |
|
|
|
|
|
class Translator(Inference): |
|
@classmethod |
|
def validate_task(cls, task): |
|
if task != ModelTask.SEQ2SEQ: |
|
raise ValueError( |
|
f"Translator does not support task {task}." |
|
f" Tasks supported: {ModelTask.SEQ2SEQ}" |
|
) |
|
|
|
def _align_forward(self, batch, predictions): |
|
""" |
|
        For a batch of inputs and their predictions, return a list of
        alignment src-index Tensors of size ``(batch, n_best,)``.
|
""" |
|
|
|
|
|
if "tgt" in batch.keys() and self.gold_align: |
|
self._log("Computing alignments with gold target") |
|
batch_tgt_idxs = batch["tgt"].transpose(1, 2) |
|
else: |
|
batch_tgt_idxs = self._align_pad_prediction( |
|
predictions, bos=self._tgt_bos_idx, pad=self._tgt_pad_idx |
|
) |
|
tgt_mask = ( |
|
batch_tgt_idxs.eq(self._tgt_pad_idx) |
|
| batch_tgt_idxs.eq(self._tgt_eos_idx) |
|
| batch_tgt_idxs.eq(self._tgt_bos_idx) |
|
) |
|
|
|
n_best = batch_tgt_idxs.size(1) |
|
|
|
src, enc_states, enc_out, src_len = self._run_encoder(batch) |
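
        # Repeat src objects n_best times so each hypothesis is aligned
        # against its own copy of the encoder output.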
|
|
|
|
|
|
|
src = tile(src, n_best, dim=0) |
|
if enc_states is not None: |
|
|
|
|
|
|
|
enc_states = tile(enc_states, n_best, dim=0) |
|
if isinstance(enc_out, tuple): |
|
enc_out = tuple(tile(x, n_best, dim=0) for x in enc_out) |
|
else: |
|
enc_out = tile(enc_out, n_best, dim=0) |
|
src_len = tile(src_len, n_best) |
|
|
|
|
|
self.model.decoder.init_state(src, enc_out, enc_states) |
|
|
|
|
|
tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1) |
|
dec_in = tgt[:-1].transpose(0, 1) |
|
|
|
_, attns = self.model.decoder(dec_in, enc_out, src_len=src_len, with_align=True) |
|
|
|
alignment_attn = attns["align"] |
|
|
|
align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1)) |
|
prediction_mask = align_tgt_mask[:, 1:] |
|
|
|
        alignment = extract_alignment(alignment_attn, prediction_mask, src_len, n_best)

        return alignment
|
|
|
def translate_batch(self, batch, attn_debug): |
|
"""Translate a batch of sentences.""" |
|
with torch.no_grad(): |
|
if self.sample_from_topk != 0 or self.sample_from_topp != 0: |
|
decode_strategy = GreedySearch( |
|
pad=self._tgt_pad_idx, |
|
bos=self._tgt_bos_idx, |
|
eos=self._tgt_eos_idx, |
|
unk=self._tgt_unk_idx, |
|
start=self._tgt_start_with, |
|
batch_size=len(batch["srclen"]), |
|
global_scorer=self.global_scorer, |
|
min_length=self.min_length, |
|
max_length=self.max_length, |
|
block_ngram_repeat=self.block_ngram_repeat, |
|
exclusion_tokens=self._exclusion_idxs, |
|
return_attention=attn_debug or self.replace_unk, |
|
sampling_temp=self.random_sampling_temp, |
|
keep_topk=self.sample_from_topk, |
|
keep_topp=self.sample_from_topp, |
|
beam_size=self.beam_size, |
|
ban_unk_token=self.ban_unk_token, |
|
) |
|
else: |
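                # Beam content dumping is not supported by BeamSearch.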
|
|
|
assert not self.dump_beam |
|
decode_strategy = BeamSearch( |
|
self.beam_size, |
|
batch_size=len(batch["srclen"]), |
|
pad=self._tgt_pad_idx, |
|
bos=self._tgt_bos_idx, |
|
eos=self._tgt_eos_idx, |
|
unk=self._tgt_unk_idx, |
|
start=self._tgt_start_with, |
|
n_best=self.n_best, |
|
global_scorer=self.global_scorer, |
|
min_length=self.min_length, |
|
max_length=self.max_length, |
|
return_attention=attn_debug or self.replace_unk, |
|
block_ngram_repeat=self.block_ngram_repeat, |
|
exclusion_tokens=self._exclusion_idxs, |
|
stepwise_penalty=self.stepwise_penalty, |
|
ratio=self.ratio, |
|
ban_unk_token=self.ban_unk_token, |
|
) |
|
return self._translate_batch_with_strategy(batch, decode_strategy) |
|
|
|
def _run_encoder(self, batch): |
|
src = batch["src"] |
|
src_len = batch["srclen"] |
|
batch_size = len(batch["srclen"]) |
|
|
|
enc_out, enc_final_hs, src_len = self.model.encoder(src, src_len) |
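
        # Some encoders do not return lengths; in that case assume every
        # sequence spans the full width of enc_out.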
|
|
|
if src_len is None: |
|
assert not isinstance( |
|
enc_out, tuple |
|
), "Ensemble decoding only supported for text data" |
|
src_len = ( |
|
torch.Tensor(batch_size).type_as(enc_out).long().fill_(enc_out.size(1)) |
|
) |
|
return src, enc_final_hs, enc_out, src_len |
|
|
|
def _translate_batch_with_strategy(self, batch, decode_strategy): |
|
"""Translate a batch of sentences step by step using cache. |
|
|
|
Args: |
|
            batch: a batch of sentences, yielded by the data iterator.
            decode_strategy (DecodeStrategy): A decode strategy to use for
                generating the translation step by step.
|
|
|
Returns: |
|
results (dict): The translation results. |
|
""" |
|
|
|
use_src_map = self.copy_attn |
|
parallel_paths = decode_strategy.parallel_paths |
|
|
|
batch_size = len(batch["srclen"]) |
|
|
|
|
|
src, enc_final_hs, enc_out, src_len = self._run_encoder(batch) |
|
|
|
self.model.decoder.init_state(src, enc_out, enc_final_hs) |
|
|
|
gold_score = self._gold_score( |
|
batch, |
|
enc_out, |
|
src_len, |
|
use_src_map, |
|
enc_final_hs, |
|
batch_size, |
|
src, |
|
) |
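
        # Prepare the decode strategy: this tiles enc_out, src_len and
        # src_map parallel_paths (i.e. beam_size) times.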
|
|
|
|
|
src_map = batch["src_map"] if use_src_map else None |
|
target_prefix = batch["tgt"] if self.tgt_file_prefix else None |
|
        fn_map_state, enc_out, src_len_tiled, src_map = decode_strategy.initialize(
|
enc_out, src_len, src_map, target_prefix=target_prefix |
|
) |
|
|
|
if fn_map_state is not None: |
|
self.model.decoder.map_state(fn_map_state) |
|
|
|
|
|
for step in range(decode_strategy.max_length): |
|
|
|
|
|
decoder_input = decode_strategy.current_predictions.view(-1, 1, 1) |
|
|
|
log_probs, attn = self._decode_and_generate( |
|
decoder_input, |
|
enc_out, |
|
batch, |
|
src_len=src_len_tiled, |
|
src_map=src_map, |
|
step=step, |
|
batch_offset=decode_strategy.batch_offset, |
|
) |
|
|
|
decode_strategy.advance(log_probs, attn) |
|
any_finished = decode_strategy.is_finished.any() |
|
if any_finished: |
|
decode_strategy.update_finished() |
|
if decode_strategy.done: |
|
break |
|
|
|
select_indices = decode_strategy.select_indices |
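
            # Drop finished hypotheses from the cached encoder output and
            # decoder state so the effective batch keeps shrinking.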
|
|
|
if any_finished: |
|
|
|
if isinstance(enc_out, tuple): |
|
enc_out = tuple(x.index_select(0, select_indices) for x in enc_out) |
|
else: |
|
enc_out = enc_out.index_select(0, select_indices) |
|
|
|
src_len_tiled = src_len_tiled.index_select(0, select_indices) |
|
|
|
if src_map is not None: |
|
src_map = src_map.index_select(0, select_indices) |
|
|
|
if parallel_paths > 1 or any_finished: |
|
self.model.decoder.map_state( |
|
lambda state, dim: state.index_select(dim, select_indices) |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return self.report_results( |
|
gold_score, |
|
batch, |
|
batch_size, |
|
src, |
|
src_len, |
|
use_src_map, |
|
decode_strategy, |
|
) |
|
|
|
def _score_target(self, batch, enc_out, src_len, src_map): |
|
tgt = batch["tgt"] |
|
tgt_in = tgt[:, :-1, :] |
|
|
|
log_probs, attn = self._decode_and_generate( |
|
tgt_in, |
|
enc_out, |
|
batch, |
|
src_len=src_len, |
|
src_map=src_map, |
|
) |
|
|
|
log_probs[:, :, self._tgt_pad_idx] = 0 |
|
gold = tgt[:, 1:, :] |
|
gold_scores = log_probs.gather(2, gold) |
|
gold_scores = gold_scores.sum(dim=1).view(-1) |
|
return gold_scores |
|
|
|
|
|
class GeneratorLM(Inference): |
|
@classmethod |
|
def validate_task(cls, task): |
|
if task != ModelTask.LANGUAGE_MODEL: |
|
raise ValueError( |
|
f"GeneratorLM does not support task {task}." |
|
f" Tasks supported: {ModelTask.LANGUAGE_MODEL}" |
|
) |
|
|
|
def _align_forward(self, batch, predictions): |
|
""" |
|
        For a batch of inputs and their predictions, return a list of
        alignment src-index Tensors of size ``(batch, n_best,)``.
|
""" |
|
raise NotImplementedError |
|
|
|
def translate_batch(self, batch, attn_debug): |
|
"""Translate a batch of sentences.""" |
|
batch_size = len(batch["srclen"]) |
|
if batch_size != 1: |
|
warning_msg = ( |
|
"GeneratorLM does not support batch_size != 1" |
|
" nicely. You can remove this limitation here." |
|
" With batch_size > 1 the end of each input is" |
|
" repeated until the input is finished. Then" |
|
" generation will start." |
|
) |
|
if self.logger: |
|
self.logger.info(warning_msg) |
|
else: |
|
os.write(1, warning_msg.encode("utf-8")) |
|
with torch.no_grad(): |
|
if self.sample_from_topk != 0 or self.sample_from_topp != 0: |
|
decode_strategy = GreedySearchLM( |
|
pad=self._tgt_pad_idx, |
|
bos=self._tgt_bos_idx, |
|
eos=self._tgt_eos_idx, |
|
unk=self._tgt_unk_idx, |
|
start=self._tgt_start_with, |
|
batch_size=len(batch["srclen"]), |
|
global_scorer=self.global_scorer, |
|
min_length=self.min_length, |
|
max_length=self.max_length, |
|
block_ngram_repeat=self.block_ngram_repeat, |
|
exclusion_tokens=self._exclusion_idxs, |
|
return_attention=attn_debug or self.replace_unk, |
|
sampling_temp=self.random_sampling_temp, |
|
keep_topk=self.sample_from_topk, |
|
keep_topp=self.sample_from_topp, |
|
beam_size=self.beam_size, |
|
ban_unk_token=self.ban_unk_token, |
|
) |
|
else: |
|
|
|
assert not self.dump_beam |
|
decode_strategy = BeamSearchLM( |
|
self.beam_size, |
|
batch_size=len(batch["srclen"]), |
|
pad=self._tgt_pad_idx, |
|
bos=self._tgt_bos_idx, |
|
eos=self._tgt_eos_idx, |
|
unk=self._tgt_unk_idx, |
|
start=self._tgt_start_with, |
|
n_best=self.n_best, |
|
global_scorer=self.global_scorer, |
|
min_length=self.min_length, |
|
max_length=self.max_length, |
|
return_attention=attn_debug or self.replace_unk, |
|
block_ngram_repeat=self.block_ngram_repeat, |
|
exclusion_tokens=self._exclusion_idxs, |
|
stepwise_penalty=self.stepwise_penalty, |
|
ratio=self.ratio, |
|
ban_unk_token=self.ban_unk_token, |
|
) |
|
return self._translate_batch_with_strategy(batch, decode_strategy) |
|
|
|
@classmethod |
|
def split_src_to_prevent_padding(cls, src, src_len): |
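        # Truncate every source to the length of the shortest one so the
        # initial LM forward pass contains no padding; the truncated tokens
        # are returned as a forced target prefix.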
|
min_len_batch = torch.min(src_len).item() |
|
target_prefix = None |
|
        if 0 < min_len_batch < src.size(1):
|
target_prefix = src[:, min_len_batch:, :] |
|
src = src[:, :min_len_batch, :] |
|
src_len[:] = min_len_batch |
|
return src, src_len, target_prefix |
|
|
|
def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs): |
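        # After the first forward pass over the full source, tile the
        # log-probs and the decoder state beam_size times and keep only the
        # distribution at the last position.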
|
if fn_map_state is not None: |
|
log_probs = fn_map_state(log_probs, dim=0) |
|
self.model.decoder.map_state(fn_map_state) |
|
log_probs = log_probs[:, -1, :] |
|
return log_probs |
|
|
|
def _translate_batch_with_strategy(self, batch, decode_strategy): |
|
"""Translate a batch of sentences step by step using cache. |
|
|
|
Args: |
|
            batch: a batch of sentences, yielded by the data iterator.
            decode_strategy (DecodeStrategy): A decode strategy to use for
                generating the translation step by step.
|
|
|
Returns: |
|
results (dict): The translation results. |
|
""" |
|
|
|
use_src_map = self.copy_attn |
|
parallel_paths = decode_strategy.parallel_paths |
|
batch_size = len(batch["srclen"]) |
|
|
|
|
|
src = batch["src"] |
|
src_len = batch["srclen"] |
|
|
|
src, src_len, target_prefix = self.split_src_to_prevent_padding(src, src_len) |
|
|
|
|
|
self.model.decoder.init_state(src, None, None) |
|
gold_score = self._gold_score( |
|
batch, |
|
None, |
|
src_len, |
|
use_src_map, |
|
None, |
|
batch_size, |
|
src, |
|
) |
|
|
|
|
|
src_map = batch["src_map"] if use_src_map else None |
|
        fn_map_state, src, src_len_tiled, src_map = decode_strategy.initialize(
|
src, |
|
src_len, |
|
src_map, |
|
target_prefix=target_prefix, |
|
) |
|
|
|
|
|
for step in range(decode_strategy.max_length): |
|
decoder_input = ( |
|
src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1) |
|
) |
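
            # Step 0 feeds the whole (truncated) source through the LM;
            # later steps feed only the last prediction, with the step
            # counter offset by the source length.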
|
|
|
log_probs, attn = self._decode_and_generate( |
|
decoder_input, |
|
None, |
|
batch, |
|
src_len=src_len_tiled.clone(), |
|
src_map=src_map, |
|
step=step if step == 0 else step + src_len[0].item(), |
|
batch_offset=decode_strategy.batch_offset, |
|
) |
|
|
|
if step == 0: |
|
log_probs = self.tile_to_beam_size_after_initial_step( |
|
fn_map_state, log_probs |
|
) |
|
|
|
decode_strategy.advance(log_probs, attn) |
|
any_finished = decode_strategy.is_finished.any() |
|
if any_finished: |
|
decode_strategy.update_finished() |
|
if decode_strategy.done: |
|
break |
|
|
|
select_indices = decode_strategy.select_indices |
|
src_len_tiled += 1 |
|
if any_finished: |
|
|
|
src_len_tiled = src_len_tiled.index_select(0, select_indices) |
|
|
|
if src_map is not None: |
|
src_map = src_map.index_select(0, select_indices) |
|
|
|
if parallel_paths > 1 or any_finished: |
|
|
|
self.model.decoder.map_state( |
|
lambda state, dim: state.index_select(dim, select_indices) |
|
) |
|
|
|
return self.report_results( |
|
gold_score, |
|
batch, |
|
batch_size, |
|
src, |
|
src_len, |
|
use_src_map, |
|
decode_strategy, |
|
) |
|
|
|
def _score_target(self, batch, enc_out, src_len, src_map): |
|
src = batch["src"] |
|
src_len = batch["srclen"] |
|
tgt = batch["tgt"] |
|
|
|
log_probs, attn = self._decode_and_generate( |
|
src, |
|
None, |
|
batch, |
|
src_len=src_len, |
|
src_map=src_map, |
|
) |
|
|
|
log_probs[:, :, self._tgt_pad_idx] = 0 |
|
gold_scores = log_probs.gather(2, tgt) |
|
gold_scores = gold_scores.sum(dim=1).view(-1) |
|
|
|
return gold_scores |
|
|