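"""
Inference-time decoding strategies (greedy decoding and beam search) for the
Transformer translation model defined in transformer.py.
"""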
import torch
from torch import Tensor

from transformer import Transformer
from tokenizers import Tokenizer
from dataset import causal_mask


def greedy_decode(
    model: Transformer, 
    src: Tensor, 
    src_mask: Tensor, 
    src_tokenizer: Tokenizer, 
    tgt_tokenizer: Tokenizer, 
    tgt_max_seq_len: int, 
    device,
    give_attn: bool = False,
):
    """
    Decodes greedily.
    """
    sos_idx = src_tokenizer.token_to_id('<sos>')
    eos_idx = src_tokenizer.token_to_id('<eos>')

    encoder_output = model.encode(src, src_mask)

    attn = None
    # Start the decoder input with the <sos> token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)

    while decoder_input.size(1) < tgt_max_seq_len:
        # Build the causal mask over the tokens generated so far; the mask
        # should match the dtype of src_mask rather than the token ids in src
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src_mask).to(device)

        # get decoder output
        decoder_output, attn = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)

        # Project the last decoder position onto the target vocabulary and
        # greedily pick the single most probable token
        prob = model.project(decoder_output[:, -1])
        _, next_word = torch.max(prob, dim=1)
        # Append the chosen token to the decoder input for the next step
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(src).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word.item() == eos_idx:
            break

    if give_attn:
        return (decoder_input.squeeze(0), attn)
    return decoder_input.squeeze(0)
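
# A minimal usage sketch for greedy_decode. The batch layout
# ('encoder_input' / 'encoder_mask' keys) and the max length of 350 are
# illustrative assumptions, not part of this file:
#
#     model_out = greedy_decode(
#         model, batch['encoder_input'].to(device), batch['encoder_mask'].to(device),
#         src_tokenizer, tgt_tokenizer, tgt_max_seq_len=350, device=device,
#     )
#     print(tgt_tokenizer.decode(model_out.detach().cpu().numpy()))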

def beam_search_decode(
    model: Transformer, 
    src: Tensor, 
    src_mask: Tensor, 
    src_tokenizer: Tokenizer, 
    tgt_tokenizer: Tokenizer, 
    tgt_max_seq_len: int, 
    device,
    beam_size: int = 3,
):
    """
    Decode with beam search: keep the beam_size highest-scoring partial
    translations at each step and return the best one. Scores are summed
    per-token log-probabilities, which assumes model.project returns
    log-probabilities (e.g. a log_softmax output).
    """
    # As in greedy_decode, the generated sequence is in the target language
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(src, src_mask)
    # Initialize the decoder input with the sos token
    decoder_initial_input = torch.empty(1,1).fill_(sos_idx).type_as(src).to(device)

    # Initialize the candidate list; in log space the empty prefix has
    # probability log(1) = 0
    candidates = [(decoder_initial_input, 0.0)]

    while True:

        # If any candidate has reached the maximum length, we have decoded for
        # tgt_max_seq_len steps, so stop the search
        if any(cand.size(1) == tgt_max_seq_len for cand, _ in candidates):
            break

        # Create a new list of candidates
        new_candidates = []

        for candidate, score in candidates:

            # Do not expand candidates that have already produced <eos>, but
            # keep them in the pool so finished hypotheses still compete with
            # unfinished ones (otherwise they would be silently dropped and
            # the all-finished check below could never fire)
            if candidate[0][-1].item() == eos_idx:
                new_candidates.append((candidate, score))
                continue

            # Build the candidate's causal mask
            candidate_mask = causal_mask(candidate.size(1)).type_as(src_mask).to(device)
            # Run the decoder (the attention weights are not needed here)
            out, _ = model.decode(encoder_output, src_mask, candidate, candidate_mask)
            # Get the next-token scores from the last decoder position
            prob = model.project(out[:, -1])
            # Keep the beam_size most probable next tokens
            topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)
            for i in range(beam_size):
                # for each of the top k candidates, get the token and its probability
                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
                token_prob = topk_prob[0][i].item()
                # create a new candidate by appending the token to the current candidate
                new_candidate = torch.cat([candidate, token], dim=1)
                # Adding log-probabilities corresponds to multiplying the
                # underlying probabilities, so a candidate's score is the
                # log-probability of its whole token sequence
                new_candidates.append((new_candidate, score + token_prob))

        # Sort the new candidates by their score
        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
        # Keep only the top k candidates
        candidates = candidates[:beam_size]

        # If all surviving candidates have produced <eos>, stop early
        if all(cand[0][-1].item() == eos_idx for cand, _ in candidates):
            break

    # Return the highest-scoring candidate, dropping the batch dimension
    return candidates[0][0].squeeze(0)
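
# A minimal usage sketch for beam_search_decode, under the same illustrative
# assumptions as the greedy example above:
#
#     model_out = beam_search_decode(
#         model, batch['encoder_input'].to(device), batch['encoder_mask'].to(device),
#         src_tokenizer, tgt_tokenizer, tgt_max_seq_len=350, device=device,
#         beam_size=3,
#     )
#     print(tgt_tokenizer.decode(model_out.detach().cpu().numpy()))
#
# A larger beam_size explores more hypotheses per step at proportionally higher
# cost; beam_size=1 reduces to greedy decoding.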