|
from typing import Optional, Tuple

import torch
from torch import Tensor
from tokenizers import Tokenizer

from transformer import Transformer
from decode_method import greedy_decode, beam_search_decode
|
|
|
def _special_token_id(tokenizer: Tokenizer, token: str) -> int:
    """Return the vocabulary id of the special ``token`` in ``tokenizer``.

    Raises:
        ValueError: if ``token`` is not in the tokenizer's vocabulary.
            ``Tokenizer.token_to_id`` returns ``None`` in that case, which
            would otherwise surface later as an obscure ``torch.tensor`` error.
    """
    token_id = tokenizer.token_to_id(token)
    if token_id is None:
        raise ValueError(f"Special token {token!r} is not in the tokenizer vocabulary")
    return token_id


def translate(
    model: Transformer,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    text: str,
    decode_method: str = 'greedy',
    device: torch.device = torch.device('cpu'),
    max_len: int = 349,
) -> Tuple[str, Optional[Tensor]]:
    """Translate ``text`` with ``model``.

    Supported ``decode_method`` values: ``'greedy'`` or ``'beam-search'``.
    ``'beam-search'`` does not produce attention scores.

    Args:
        model: The transformer to decode with; it is switched to eval mode.
        src_tokenizer: Tokenizer for ``text``. It also supplies the
            ``<sos>``, ``<eos>`` and ``<pad>`` ids used to build the encoder
            input.
        tgt_tokenizer: Tokenizer used to turn the decoded ids back into text.
        text: The source string to translate.
        decode_method: ``'greedy'`` or ``'beam-search'``.
        device: Device the encoder input/mask are created on; it is also
            forwarded to the decode function. Assumes ``model`` already lives
            on this device.
        max_len: Maximum length forwarded to the decode function. The
            default of 349 matches the previously hard-coded value.

    Returns:
        A ``(translation, attn)`` tuple:

        - translation (str): the translated string.
        - attn (Tensor or None): the decoder's attention returned by
          ``greedy_decode`` (for visualization); ``None`` for beam search.

    Raises:
        ValueError: if ``decode_method`` is not supported, or a required
            special token is missing from ``src_tokenizer``.
    """
    # Validate before doing any tokenization or model work.
    if decode_method not in ('greedy', 'beam-search'):
        raise ValueError(f"Unsupported decode method: {decode_method!r}")

    model.eval()
    with torch.no_grad():
        sos_id = _special_token_id(src_tokenizer, '<sos>')
        eos_id = _special_token_id(src_tokenizer, '<eos>')
        pad_id = _special_token_id(src_tokenizer, '<pad>')

        # <sos> + source token ids + <eos>. Built directly on `device` (the
        # same device handed to the decode functions) instead of on the CPU.
        token_ids = [sos_id, *src_tokenizer.encode(text).ids, eos_id]
        encoder_input = torch.tensor(token_ids, dtype=torch.int64, device=device)

        # Padding mask: 1 for non-<pad> tokens, 0 for <pad>; shape (1, 1, 1, seq_len).
        encoder_mask = (encoder_input != pad_id).int().view(1, 1, 1, -1)

        # Add the batch dimension: (seq_len,) -> (1, seq_len).
        encoder_input = encoder_input.unsqueeze(0)

        attn: Optional[Tensor] = None
        if decode_method == 'greedy':
            model_out, attn = greedy_decode(
                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer,
                max_len, device,
                give_attn=True,
            )
        else:  # 'beam-search': no attention scores are returned
            model_out = beam_search_decode(
                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer,
                max_len, device,
            )

        # .tolist() hands plain Python ints to Tokenizer.decode.
        translation = tgt_tokenizer.decode(model_out.detach().cpu().tolist())

    return translation, attn
|
|
|
|
|
from config import load_config |
|
from load_model import load_model_tokenizer |
|
if __name__ == '__main__':
    # Demo run: load the final config, build the model and both tokenizers,
    # then translate a sample Lao string with beam search and print it.
    cfg = load_config(file_name='/config/config_final.yaml')
    translator, source_tok, target_tok = load_model_tokenizer(cfg)

    sample = "ສະບາຍດີ"

    # Beam search yields no attention scores, so the second element is ignored.
    result, _ = translate(
        translator,
        source_tok,
        target_tok,
        sample,
        decode_method='beam-search',
    )
    print(result)