# homemade_lo_vi / decode_method.py
# Author: moiduy04
# Commit message: Upload 12 files
# Commit: b8a6dde
import torch
from torch import Tensor
from transformer import Transformer
from tokenizers import Tokenizer
from dataset import causal_mask
def greedy_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    give_attn: bool = False,
):
    """
    Decode one source sequence greedily, one target token at a time.

    At every step the highest-scoring next token is appended to the target
    sequence. Decoding stops when the ``<eos>`` token is produced or when the
    target reaches ``tgt_max_seq_len`` tokens (the leading ``<sos>`` included).

    Args:
        model: Model exposing ``encode``, ``decode`` and ``project``.
        src: Source token ids. The target is built with batch size 1, so this
            presumably has shape (1, src_len) -- TODO confirm with callers.
        src_mask: Source mask, passed unchanged to ``encode`` and ``decode``.
        src_tokenizer: Source tokenizer. Kept for interface compatibility;
            it is not used for decoding.
        tgt_tokenizer: Target tokenizer. Supplies the ``<sos>``/``<eos>`` ids.
        tgt_max_seq_len: Maximum target length, ``<sos>`` included.
        device: Device the target tensors are moved to.
        give_attn: If True, also return the second value produced by the last
            ``model.decode`` call (``None`` if no decoding step ran).

    Returns:
        1-D tensor of target token ids starting with ``<sos>``, or the tuple
        ``(ids, attn)`` when ``give_attn`` is True.
    """
    # Special tokens must come from the *target* vocabulary: the decoder
    # consumes and produces target ids, and the two vocabularies need not
    # assign the same ids to '<sos>'/'<eos>'.
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')

    # The encoder output does not depend on the target, so compute it once.
    encoder_output = model.encode(src, src_mask)
    attn = None
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)

    # `<` rather than an equality test, so a tgt_max_seq_len below 1 cannot
    # make the loop run until '<eos>' happens to be produced.
    while decoder_input.size(1) < tgt_max_seq_len:
        # Causal mask so each target position only attends to earlier ones.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src).to(device)
        decoder_output, attn = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)
        # Only the output at the last position predicts the next token.
        prob = model.project(decoder_output[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(src).fill_(next_word.item()).to(device)],
            dim=1,
        )
        if next_word.item() == eos_idx:
            break

    if give_attn:
        return (decoder_input.squeeze(0), attn)
    return decoder_input.squeeze(0)
def beam_search_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    beam_size: int = 3,
):
    """
    Decode one source sequence with beam search.

    The search keeps the ``beam_size`` hypotheses with the highest cumulative
    log-probability. Hypotheses that already ended with ``<eos>`` stay in the
    pool, unexpanded, so they can still be returned. The search stops when
    every kept hypothesis ends with ``<eos>`` or when any hypothesis reaches
    ``tgt_max_seq_len`` tokens (``<sos>`` included). No length normalisation
    is applied.

    Args:
        model: Model exposing ``encode``, ``decode`` and ``project``.
        src: Source token ids. Hypotheses are built with batch size 1, so this
            presumably has shape (1, src_len) -- TODO confirm with callers.
        src_mask: Source mask, passed unchanged to ``encode`` and ``decode``.
        src_tokenizer: Source tokenizer. Kept for interface compatibility;
            it is not used for decoding.
        tgt_tokenizer: Target tokenizer. Supplies the ``<sos>``/``<eos>`` ids.
        tgt_max_seq_len: Maximum target length, ``<sos>`` included.
        device: Device the initial target tensor is moved to.
        beam_size: Number of hypotheses kept per step (must be >= 1).

    Returns:
        1-D tensor of token ids of the best-scoring hypothesis, starting with
        ``<sos>``.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError(f"beam_size must be >= 1, got {beam_size}")

    # Special tokens must come from the *target* vocabulary: hypotheses are
    # sequences of target ids.
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')

    # The encoder output does not depend on the target; compute it once.
    encoder_output = model.encode(src, src_mask)

    decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)
    # Each candidate is (ids of shape (1, length), cumulative log-probability).
    # log(1) == 0 is the neutral starting value for a sum of log-probabilities.
    candidates = [(decoder_initial_input, 0.0)]

    while True:
        # `>=` rather than `==`: finished hypotheses have different lengths, and
        # a tgt_max_seq_len below 1 must not make the search run unbounded.
        if any(cand.size(1) >= tgt_max_seq_len for cand, _ in candidates):
            break

        new_candidates = []
        for candidate, score in candidates:
            if candidate[0][-1].item() == eos_idx:
                # Finished hypothesis: carry it over unchanged so it can still
                # win, but do not expand it further.
                new_candidates.append((candidate, score))
                continue

            # Causal mask so each target position only attends to earlier ones.
            candidate_mask = causal_mask(candidate.size(1)).type_as(src_mask).to(device)
            out, _ = model.decode(encoder_output, src_mask, candidate, candidate_mask)
            # log_softmax turns raw scores into log-probabilities and leaves values
            # that already are log-probabilities unchanged, so summing is valid
            # either way.
            log_probs = torch.log_softmax(model.project(out[:, -1]), dim=-1)
            # Cannot request more tokens than the vocabulary holds.
            k = min(beam_size, log_probs.size(-1))
            topk_log_probs, topk_idx = torch.topk(log_probs, k, dim=1)

            for token_log_prob, token_id in zip(topk_log_probs[0].tolist(), topk_idx[0]):
                # Match the candidate's dtype/device before concatenating.
                token = token_id.view(1, 1).type_as(candidate)
                new_candidate = torch.cat([candidate, token], dim=1)
                new_candidates.append((new_candidate, score + token_log_prob))

        # Keep only the best `beam_size` hypotheses.
        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_size]

        if all(cand[0][-1].item() == eos_idx for cand, _ in candidates):
            break

    # squeeze(0) (not squeeze()) so a length-1 result stays 1-D, as in greedy_decode.
    return candidates[0][0].squeeze(0)