from collections import namedtuple

import numpy as np
import torch

from fairseq import utils
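

# DecoderOut carries the decoder state across refinement iterations: the
# current output tokens and scores, optional attention, the step counters,
# and (when retain_history is set) the intermediate outputs of each pass.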
DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
)


class IterativeRefinementGenerator(object):
    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
        """
        Generates translations based on iterative refinement.

        Args:
            tgt_dict: target dictionary
            models: list of models; if reranking is enabled, the last model
                is used as the reranker
            eos_penalty: if > 0.0, penalizes early stopping during decoding
            max_iter: maximum number of refinement iterations
            max_ratio: generate sequences of maximum length ax, where x is
                the source length
            beam_size: length-beam size (number of target-length candidates
                decoded per sentence)
            decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
            retain_dropout: retain dropout during inference
            adaptive: stop refining a hypothesis as soon as it stops changing
            retain_history: keep the intermediate output of every iteration
            reranking: rerank the length-beam candidates with the last model
        """
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length (currently unused)
            cuda: use GPU for generation (currently unused)
            timer: StopwatchMeter for timing generations.
        """
        for sample in data_itr:
            if "net_input" not in sample:
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])
            for i, id in enumerate(sample["id"]):
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        if not self.retain_dropout:
            for model in models:
                model.eval()
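
        # When reranking, the last checkpoint is split off and used as an
        # autoregressive reranker; the remaining models decode (optionally
        # as an ensemble, if the model class supports it).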
        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translations for each example"

            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)
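
        # Encode the source once; the model predicts an initial target
        # (a length and placeholder tokens) that the refinement loop updates.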
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
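
        # Length beam: replicate each sentence beam_size times so that several
        # target-length candidates can be refined in parallel.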
        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]
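
        # A hypothesis is considered finished when an iteration reproduces the
        # previous output. The two tensors may differ in length, so the shorter
        # one is right-padded before the element-wise comparison.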
        def is_a_loop(x, y, s, a):
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a
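
        # Strip padding and package one finished hypothesis in the standard
        # fairseq output format; the sentence score is the mean token score.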
        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }
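
        # Refinement loop: at most max_iter + 1 decoder passes. Finished
        # hypotheses are removed from the batch after every pass, so later
        # passes only process the unfinished ones.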
        for step in range(self.max_iter + 1):
            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )
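
            # With adaptive decoding, a hypothesis terminates as soon as its
            # output stops changing; on the final iteration everything is
            # force-terminated.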
            if self.adaptive:
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )

            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:
                terminated.fill_(1)
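
            # Collect the hypotheses that finished at this step.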
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]
            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            if terminated.sum() == terminated.size(0):
                break
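
            # Keep only the unfinished hypotheses for the next iteration,
            # shrinking the decoder state, encoder output, and index map.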
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()
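
        # With a length beam, optionally rerank the candidates, then keep the
        # highest-scoring candidate for each source sentence.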
        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]

        return finalized
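
    # Re-scores each finalized candidate with an autoregressive reranker and
    # overwrites its "score" with the length-normalized log-probability.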
    def rerank(self, reranker, finalized, encoder_input, beam_size):
        def rebuild_batch(finalized):
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens
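
        # The autoregressive decoder expects EOS as its first input token, so
        # overwrite position 0 before scoring.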
        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[:, 0] = self.eos
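
        # Encode the source once, then repeat each source sentence beam_size
        # times to line up with its beam_size candidates. A candidate's score
        # is its length-normalized log-probability under the reranker.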
        reranker_encoder_out = reranker.encoder(*encoder_input)
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized