File size: 2,710 Bytes
bc1ada8
 
 
 
 
 
 
 
b8a6dde
bc1ada8
 
 
 
 
 
 
 
 
 
 
 
b8a6dde
 
 
 
bc1ada8
 
 
 
b8a6dde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc1ada8
b8a6dde
 
 
 
bc1ada8
b8a6dde
 
 
 
 
 
 
 
 
 
 
 
bc1ada8
b8a6dde
 
bc1ada8
 
 
5d0c765
bc1ada8
5d0c765
bc1ada8
 
 
b8a6dde
 
bc1ada8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import Optional, Tuple

import torch
from torch import Tensor
from tokenizers import Tokenizer

from transformer import Transformer
from decode_method import greedy_decode, beam_search_decode

def _special_token_tensor(tokenizer: Tokenizer, token: str) -> Tensor:
    """Return a 1-element int64 tensor holding the id of `token` in `tokenizer`."""
    token_id = tokenizer.token_to_id(token)
    if token_id is None:
        # token_to_id returns None for tokens missing from the vocabulary; fail
        # with a clear message instead of torch.tensor([None])'s TypeError.
        raise ValueError(f"Special token {token!r} not found in tokenizer vocabulary")
    return torch.tensor([token_id], dtype=torch.int64)


def translate(
    model: Transformer,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    text: str,
    decode_method: str = 'greedy',
    device: torch.device = torch.device('cpu'),
    max_len: int = 349,
) -> Tuple[str, Optional[Tensor]]:
    """
    Translate `text` with `model`.

    Supported `decode_method`: 'greedy' or 'beam-search'.
    'beam-search' doesn't give attention scores, so `attn` is None in that case.

    Args:
        model: the Transformer to run. It is switched to eval mode, and this
            persists after the call.
        src_tokenizer: tokenizer for the source text. It also supplies the
            '<sos>', '<eos>' and '<pad>' ids used to build the encoder input.
        tgt_tokenizer: tokenizer used to decode the generated token ids.
        text: the source sentence to translate.
        decode_method: 'greedy' or 'beam-search'.
        device: device the encoder input and mask are moved to. It is also
            forwarded to the decoding function.
        max_len: maximum length forwarded to the decoding function
            (defaults to the previously hard-coded 349).

    Returns:
        - translation (str): the translated string.
        - attn (Tensor | None): the decoder's attention returned by
          `greedy_decode` (for visualization); None for 'beam-search'.

    Raises:
        ValueError: if `decode_method` is unsupported, or a special token is
            missing from `src_tokenizer`'s vocabulary.
    """
    # Validate up front so no tokenization/model work is wasted on bad input.
    if decode_method not in ('greedy', 'beam-search'):
        raise ValueError(f"Unsupported decode method: {decode_method!r}")

    model.eval()
    with torch.no_grad():
        sos_token = _special_token_tensor(src_tokenizer, '<sos>')
        eos_token = _special_token_tensor(src_tokenizer, '<eos>')
        pad_token = _special_token_tensor(src_tokenizer, '<pad>')

        encoder_input_tokens = src_tokenizer.encode(text).ids
        # <sos> + source_text + <eos> = encoder_input, shape (seq_len,)
        encoder_input = torch.cat(
            [
                sos_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                eos_token,
            ]
        )
        # Three unsqueezes turn (seq_len,) into (1, 1, 1, seq_len).
        # NOTE(review): presumably this matches the batched mask layout that
        # greedy_decode / beam_search_decode expect — confirm there.
        # A single unpadded sentence yields an all-ones mask.
        encoder_mask = (
            (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).unsqueeze(0).int()
        )

        # Add the batch dimension -> (1, seq_len), and place inputs on `device`
        # so they live on the same device the caller asked for.
        encoder_input = encoder_input.unsqueeze(0).to(device)
        encoder_mask = encoder_mask.to(device)

        attn: Optional[Tensor]
        if decode_method == 'greedy':
            model_out, attn = greedy_decode(
                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, max_len, device,
                give_attn=True,
            )
        else:  # 'beam-search' (validated above)
            model_out = beam_search_decode(
                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, max_len, device,
            )
            attn = None  # beam_search_decode does not return attention scores

        # Tokenizer.decode takes a sequence of integer ids; .tolist() yields plain ints.
        model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().tolist())
    return model_out_text, attn
    

from config import load_config
from load_model import load_model_tokenizer
if __name__ == '__main__':
    # Manual smoke run: load the final config/model and translate one sample
    # sentence with beam search, printing only the text.
    cfg = load_config(file_name='/config/config_final.yaml')
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(cfg)

    sample_text = "ສະບາຍດີ"  # "Hello."
    translation, _attn = translate(
        model,
        src_tokenizer,
        tgt_tokenizer,
        sample_text,
        decode_method='beam-search',
    )
    print(translation)