|
import torch
from torch import Tensor

from tokenizers import Tokenizer

from transformer import Transformer
from dataset import causal_mask
|
|
def greedy_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    give_attn: bool = False,
):
    """
    Decode greedily: at every step feed the decoder the tokens generated so
    far and append the single highest-probability next token, stopping at
    <eos> or after tgt_max_seq_len tokens.

    Returns the generated token ids as a 1-D tensor; if give_attn is True,
    returns a (token_ids, attention) tuple instead.
    """
|
    # Special-token ids must come from the *target* tokenizer, since the
    # decoder generates target-vocabulary tokens.
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')

    # Run the encoder once; its output is reused at every decoding step.
    encoder_output = model.encode(src, src_mask)

    attn = None
    # Start the target sequence with a single <sos> token.
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)
|
    while True:
        if decoder_input.size(1) == tgt_max_seq_len:
            break

        # Causal mask so each position only attends to earlier positions.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src).to(device)

        decoder_output, attn = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)

        # Project only the last position and greedily pick the argmax token.
        prob = model.project(decoder_output[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(src).fill_(next_word.item()).to(device)],
            dim=1,
        )

        if next_word.item() == eos_idx:
            break

    if give_attn:
        return (decoder_input.squeeze(0), attn)
    return decoder_input.squeeze(0)
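
# Example use of greedy_decode (a sketch: the model/tokenizer setup is
# assumed here, see the __main__ block at the bottom of this file, and
# 350 is just an example length):
#
#     tokens, attn = greedy_decode(model, src, src_mask, src_tokenizer,
#                                  tgt_tokenizer, tgt_max_seq_len=350,
#                                  device=device, give_attn=True)
#     text = tgt_tokenizer.decode(tokens.detach().cpu().tolist())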
|
|
|
def beam_search_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    beam_size: int = 3,
):
    """
    Decode with beam search: keep the beam_size highest-scoring partial
    sequences at each step and return the best complete one as a 1-D tensor.
    """
    # As in greedy_decode, special-token ids come from the target tokenizer.
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')
|
    encoder_output = model.encode(src, src_mask)

    decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)

    # Each candidate is a (sequence, score) pair. Scores are accumulated by
    # addition below, i.e. summed log-probabilities (assuming model.project
    # returns log-probabilities), so the initial score is 0.0, not 1.
    candidates = [(decoder_initial_input, 0.0)]
|
    while True:
        # Stop if any candidate has reached the maximum target length.
        if any(cand.size(1) == tgt_max_seq_len for cand, _ in candidates):
            break

        new_candidates = []

        for candidate, score in candidates:
            # A finished candidate (one that already ends with <eos>) is
            # carried forward unchanged so it keeps competing for the final
            # ranking; dropping it here would discard completed hypotheses.
            if candidate[0][-1].item() == eos_idx:
                new_candidates.append((candidate, score))
                continue
|
            # Causal mask sized to the current candidate length.
            candidate_mask = causal_mask(candidate.size(1)).type_as(src_mask).to(device)

            out, _ = model.decode(encoder_output, src_mask, candidate, candidate_mask)

            # Project only the last position to score the next token.
            prob = model.project(out[:, -1])

            # Expand this candidate with its beam_size best next tokens.
            topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)
            for i in range(beam_size):
                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
                token_prob = topk_prob[0][i].item()

                new_candidate = torch.cat([candidate, token], dim=1)
                new_candidates.append((new_candidate, score + token_prob))
|
        # Keep only the beam_size highest-scoring candidates.
        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
        candidates = candidates[:beam_size]

        # Stop once every surviving candidate has produced <eos>.
        if all(cand[0][-1].item() == eos_idx for cand, _ in candidates):
            break

    # Return the best-scoring sequence as a 1-D tensor of token ids.
    return candidates[0][0].squeeze()
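

if __name__ == "__main__":
    # Minimal end-to-end sketch. The file paths, the sentence, the mask shape,
    # and the model construction below are assumptions for illustration;
    # substitute your own trained Transformer and tokenizer files.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    src_tokenizer = Tokenizer.from_file("tokenizer_src.json")  # hypothetical path
    tgt_tokenizer = Tokenizer.from_file("tokenizer_tgt.json")  # hypothetical path

    model: Transformer = ...  # build/load your trained model here (not shown)

    # Wrap the source sentence in <sos>/<eos> and build a no-padding encoder
    # mask (the (1, 1, 1, seq_len) shape is an assumption about how src_mask
    # is consumed by the attention layers).
    ids = src_tokenizer.encode("hello world").ids
    src = torch.tensor(
        [src_tokenizer.token_to_id('<sos>')] + ids + [src_tokenizer.token_to_id('<eos>')],
        dtype=torch.int64,
        device=device,
    ).unsqueeze(0)
    src_mask = torch.ones(1, 1, 1, src.size(1), dtype=torch.int64, device=device)

    out = beam_search_decode(model, src, src_mask, src_tokenizer, tgt_tokenizer,
                             tgt_max_seq_len=350, device=device, beam_size=3)
    print(tgt_tokenizer.decode(out.detach().cpu().tolist()))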