import torch
from torch import Tensor
from transformer import Transformer
from tokenizers import Tokenizer
from dataset import causal_mask


def greedy_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    give_attn: bool = False,
):
    """
    Decodes greedily: at each step, append the single most probable next
    token until the eos token is produced or tgt_max_seq_len is reached.
    """
    # The decoded ids live in the target vocabulary, so look up the special
    # tokens with the target tokenizer
    sos_idx = tgt_tokenizer.token_to_id('<s>')
    eos_idx = tgt_tokenizer.token_to_id('</s>')

    # Precompute the encoder output and reuse it for every decoding step
    encoder_output = model.encode(src, src_mask)
    attn = None

    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)
    while True:
        if decoder_input.size(1) == tgt_max_seq_len:
            break

        # Build the causal mask for the tokens decoded so far
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src).to(device)

        # Run the decoder and project the last position onto the vocabulary
        decoder_output, attn = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)
        prob = model.project(decoder_output[:, -1])

        # Pick the most probable next token and append it to the decoder input
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(src).fill_(next_word.item()).to(device)],
            dim=1,
        )

        if next_word == eos_idx:
            break

    if give_attn:
        return (decoder_input.squeeze(0), attn)
    return decoder_input.squeeze(0)


def beam_search_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    beam_size: int = 3,
):
    """
    Decodes with beam search: keep the beam_size highest-scoring partial
    sequences at each step and return the best completed one.
    """
    sos_idx = tgt_tokenizer.token_to_id('<s>')
    eos_idx = tgt_tokenizer.token_to_id('</s>')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(src, src_mask)

    # Initialize the decoder input with the sos token
    decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)

    # Each candidate is a (sequence, score) pair; scores are summed log
    # probabilities, so the empty prefix starts at 0
    candidates = [(decoder_initial_input, 0.0)]

    while True:
        # If a candidate has reached the maximum length, we have run the
        # decoding for at least tgt_max_seq_len iterations, so stop the search
        if any(cand.size(1) == tgt_max_seq_len for cand, _ in candidates):
            break

        # Create a new list of candidates
        new_candidates = []

        for candidate, score in candidates:
            # Carry finished candidates (ending in eos) over unchanged instead
            # of expanding them, so they can still win the final ranking
            if candidate[0][-1].item() == eos_idx:
                new_candidates.append((candidate, score))
                continue

            # Build the candidate's causal mask
            candidate_mask = causal_mask(candidate.size(1)).type_as(src_mask).to(device)

            # Run the decoder and project the last position onto the vocabulary
            out, _ = model.decode(encoder_output, src_mask, candidate, candidate_mask)
            prob = model.project(out[:, -1])

            # Take the beam_size most probable next tokens
            topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)
            for i in range(beam_size):
                # For each of the top k tokens, get the token and its log probability
                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
                token_prob = topk_prob[0][i].item()

                # Create a new candidate by appending the token to the current candidate
                new_candidate = torch.cat([candidate, token], dim=1)

                # Sum the log probabilities because the projection is in log space
                new_candidates.append((new_candidate, score + token_prob))

        # Sort the new candidates by their score and keep only the top beam_size
        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
        candidates = candidates[:beam_size]

        # If all the candidates have reached the eos token, stop
        if all(cand[0][-1].item() == eos_idx for cand, _ in candidates):
            break

    # Return the best candidate
    return candidates[0][0].squeeze()
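

# --- Usage sketch ------------------------------------------------------------
# A minimal, hedged example of driving the two decoders above; it is not part
# of the original module. It assumes a trained Transformer saved whole at
# "model.pt" and tokenizer files "tokenizer_src.json" / "tokenizer_tgt.json";
# all three paths are hypothetical, as are the 350 token budget and the
# (batch, 1, 1, seq_len) mask convention. The source mask is all ones because
# a single unpadded sentence needs no padding mask; adapt this if your
# pipeline adds <s>/</s> framing or padding to the source side.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    src_tokenizer = Tokenizer.from_file("tokenizer_src.json")  # hypothetical path
    tgt_tokenizer = Tokenizer.from_file("tokenizer_tgt.json")  # hypothetical path
    model = torch.load("model.pt", map_location=device)        # hypothetical checkpoint
    model.eval()

    # Tokenize one source sentence into a (1, seq_len) batch of ids
    ids = src_tokenizer.encode("How are you?").ids
    src = torch.tensor([ids], dtype=torch.int64, device=device)
    src_mask = torch.ones(1, 1, 1, src.size(1), dtype=torch.int64, device=device)

    with torch.no_grad():
        greedy_ids = greedy_decode(
            model, src, src_mask, src_tokenizer, tgt_tokenizer,
            tgt_max_seq_len=350, device=device,
        )
        beam_ids = beam_search_decode(
            model, src, src_mask, src_tokenizer, tgt_tokenizer,
            tgt_max_seq_len=350, device=device, beam_size=3,
        )

    print("greedy:", tgt_tokenizer.decode(greedy_ids.tolist()))
    print("beam:  ", tgt_tokenizer.decode(beam_ids.tolist()))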