import torch
from torch import Tensor
from tokenizers import Tokenizer

from dataset import causal_mask
from transformer import Transformer


def greedy_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    give_attn: bool = False,
):
    """
    Decodes greedily: at each step, run the decoder over the tokens generated
    so far and append the single highest-probability next token, stopping at
    the end-of-sequence token or `tgt_max_seq_len`. Returns the generated
    token ids, plus the final decoder attention weights if `give_attn` is True.
    """
    # Special-token ids must come from the *target* tokenizer, since the
    # decoder generates target-side tokens. (`src_tokenizer` is unused here
    # but kept in the signature for call-site compatibility.) The token
    # strings below are assumed; adjust them to match the trained vocabulary.
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')

    # Encode the source once; the encoder output is reused at every decode step.
    encoder_output = model.encode(src, src_mask)

    attn = None
    # Start the target sequence with the start-of-sequence token.
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)

    while True:
        if decoder_input.size(1) == tgt_max_seq_len:
            break

        # Build a causal target mask so each position attends only to
        # earlier positions.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src).to(device)

        # Run the decoder over the whole generated prefix.
        decoder_output, attn = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)

        # Project the last position to vocabulary logits and take the argmax.
        prob = model.project(decoder_output[:, -1])
        _, next_word = torch.max(prob, dim=1)

        # Append the chosen token to the decoder input for the next step.
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(src).fill_(next_word.item()).to(device)],
            dim=1,
        )

        if next_word == eos_idx:
            break

    if give_attn:
        return decoder_input.squeeze(0), attn
    return decoder_input.squeeze(0)
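

# Usage sketch: a minimal, hedged example of driving greedy_decode at
# inference time. The `translate` helper itself, the int64 dtype, and the
# src_mask shape are assumptions (common conventions for this kind of
# encoder-decoder setup), not something this module specifies; the model is
# taken as an argument because the module does not show how the Transformer
# is constructed or loaded.
def translate(
    model: Transformer,
    sentence: str,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
) -> str:
    """Encode a sentence, greedily decode it, and detokenize the result."""
    model.eval()

    # Tokenize the source sentence into a (1, seq_len) batch of ids.
    src_ids = src_tokenizer.encode(sentence).ids
    src = torch.tensor([src_ids], dtype=torch.int64, device=device)

    # Attend to every source position; adjust this shape if the model
    # expects a different mask convention (this one is an assumption).
    src_mask = torch.ones(1, 1, 1, src.size(1), dtype=torch.int64, device=device)

    with torch.no_grad():
        out_ids = greedy_decode(
            model, src, src_mask, src_tokenizer, tgt_tokenizer,
            tgt_max_seq_len, device,
        )

    # greedy_decode returns a 1-D tensor of target token ids.
    return tgt_tokenizer.decode(out_ids.tolist())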