homemade_lo_vi / translate.py
moiduy04's picture
Update translate.py
5d0c765
raw
history blame
2.71 kB
from typing import Optional, Tuple

import torch
from torch import Tensor
from tokenizers import Tokenizer

from transformer import Transformer
from decode_method import greedy_decode, beam_search_decode
def _special_token_id(tokenizer: Tokenizer, token: str) -> Tensor:
    """Return the id of ``token`` as a 1-element int64 tensor.

    Raises:
        ValueError: if ``token`` is not in the tokenizer's vocabulary
            (``Tokenizer.token_to_id`` returns None in that case).
    """
    token_id = tokenizer.token_to_id(token)
    if token_id is None:
        raise ValueError(f"Tokenizer vocabulary has no {token!r} token")
    return torch.tensor([token_id], dtype=torch.int64)


def translate(
    model: Transformer,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    text: str,
    decode_method: str = 'greedy',
    device = torch.device('cpu'),
    max_len: int = 349,
) -> Tuple[str, Optional[Tensor]]:
    """Translate ``text`` with ``model``.

    The source text is wrapped as ``<sos> + tokens + <eos>`` (using the
    source tokenizer's special tokens), run through the chosen decoder, and
    the resulting ids are decoded with ``tgt_tokenizer``.

    Args:
        model: the trained Transformer; it is switched to eval mode.
        src_tokenizer: tokenizer for the source language. It must define
            ``<sos>``, ``<eos>`` and ``<pad>``.
        tgt_tokenizer: tokenizer for the target language, used for decoding
            (and passed through to the decode function).
        text: the source sentence.
        decode_method: ``'greedy'`` or ``'beam-search'``.
        device: device the encoder input and mask are moved to. It is also
            forwarded to the decode function.
        max_len: length limit forwarded to the decode function. It is
            presumably the maximum output length; confirm in
            ``decode_method``. The default 349 is the previously hard-coded
            value.

    Returns:
        A tuple ``(translation, attn)``:
            - translation (str): the translated string.
            - attn (Tensor | None): the decoder's attention, for
              visualization. It is None for ``'beam-search'``, which does
              not produce attention scores.

    Raises:
        ValueError: for an unsupported ``decode_method``, or when a
            required special token is missing from ``src_tokenizer``.
    """
    model.eval()
    with torch.no_grad():
        sos_token = _special_token_id(src_tokenizer, '<sos>')
        eos_token = _special_token_id(src_tokenizer, '<eos>')
        pad_token = _special_token_id(src_tokenizer, '<pad>')

        # <sos> + source_text + <eos> = encoder_input, shape (seq_len,)
        encoder_input = torch.cat(
            [
                sos_token,
                torch.tensor(src_tokenizer.encode(text).ids, dtype=torch.int64),
                eos_token,
            ]
        )
        # (seq_len,) -> (1, 1, 1, seq_len): a batch of one whose mask presumably
        # broadcasts over heads / query positions — confirm against Transformer.
        encoder_mask = (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).unsqueeze(0).int()
        encoder_input = encoder_input.unsqueeze(0)  # (1, seq_len)

        # Put the inputs on the same device the caller asked for (no-op if
        # the decode function moves them again).
        encoder_input = encoder_input.to(device)
        encoder_mask = encoder_mask.to(device)

        if decode_method == 'greedy':
            model_out, attn = greedy_decode(
                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, max_len, device,
                give_attn=True,
            )
        elif decode_method == 'beam-search':
            model_out = beam_search_decode(
                model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, max_len, device,
            )
            attn = None  # Beam search doesn't give attention scores
        else:
            raise ValueError(
                f"Unsupported decode method: {decode_method!r} "
                "(expected 'greedy' or 'beam-search')"
            )

        # Tokenizer.decode takes a list of ints, so convert the tensor explicitly.
        model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().tolist())
    return model_out_text, attn
from config import load_config
from load_model import load_model_tokenizer
if __name__ == '__main__':
    # Demo entry point: load the final configuration and its model/tokenizers,
    # then translate one short sample sentence using beam search.
    cfg = load_config(file_name='/config/config_final.yaml')
    loaded_model, src_tok, tgt_tok = load_model_tokenizer(cfg)

    sample_text = "ສະບາຍດີ"  # "Hello."

    # Beam search yields attn=None, so the second element is discarded.
    translated, _ = translate(
        loaded_model,
        src_tok,
        tgt_tok,
        sample_text,
        decode_method='beam-search',
    )
    print(translated)