from typing import Optional, Tuple

import torch
from torch import Tensor
from tokenizers import Tokenizer

from transformer import Transformer
from decode_method import greedy_decode, beam_search_decode


_DECODE_METHODS = ('greedy', 'beam-search')


def _special_token_id(tokenizer: Tokenizer, token: str) -> int:
    """Return the vocabulary id of ``token``, raising if the tokenizer lacks it.

    ``Tokenizer.token_to_id`` returns ``None`` for unknown tokens; letting that
    reach ``torch.tensor`` would fail later with an unhelpful TypeError.
    """
    token_id = tokenizer.token_to_id(token)
    if token_id is None:
        raise ValueError(
            f"Special token {token!r} is not in the source tokenizer's vocabulary"
        )
    return token_id


def translate(
    model: Transformer,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    text: str,
    decode_method: str = 'greedy',
    device: torch.device = torch.device('cpu'),
    max_len: int = 349,
    *,
    sos: str = '<s>',
    eos: str = '</s>',
    pad: str = '<pad>',
) -> Tuple[str, Optional[Tensor]]:
    """Translate ``text`` with ``model`` and return the decoded string.

    Args:
        model: The Transformer to run; it is switched to eval mode.
        src_tokenizer: Tokenizer used to encode ``text``; it must contain the
            ``sos``/``eos``/``pad`` special tokens.
        tgt_tokenizer: Tokenizer used to decode the generated ids.
        text: The source sentence.
        decode_method: ``'greedy'`` or ``'beam-search'``.
        device: Device the encoder input and mask are moved to; also forwarded
            to the decoding routine.
        max_len: Maximum length forwarded to the decoding routine.
        sos, eos, pad: Source-side special token strings.
            NOTE(review): these defaults are assumed — they must match the
            special tokens the source tokenizer was trained with; confirm.

    Returns:
        ``(translation, attn)`` where ``translation`` is the decoded string and
        ``attn`` is the attention returned by ``greedy_decode`` (for
        visualization), or ``None`` for ``'beam-search'``, which does not
        produce attention scores.

    Raises:
        ValueError: If ``decode_method`` is unsupported or a special token is
            missing from ``src_tokenizer``.
    """
    # Validate up front so no work is done for an unsupported method.
    if decode_method not in _DECODE_METHODS:
        raise ValueError(f"Unsupported decode method: {decode_method!r}")

    model.eval()
    with torch.no_grad():
        sos_token = torch.tensor([_special_token_id(src_tokenizer, sos)], dtype=torch.int64)
        eos_token = torch.tensor([_special_token_id(src_tokenizer, eos)], dtype=torch.int64)
        pad_token = torch.tensor([_special_token_id(src_tokenizer, pad)], dtype=torch.int64)

        encoder_input_tokens = src_tokenizer.encode(text).ids

        # sos + source tokens + eos = encoder input, shape (seq_len,)
        encoder_input = torch.cat(
            [
                sos_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                eos_token,
            ]
        )

        # Padding mask of shape (1, 1, 1, seq_len). No padding is added here, so
        # it is all ones; presumably the model broadcasts it over heads and
        # query positions — confirm against the Transformer implementation.
        encoder_mask = (
            (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).unsqueeze(0).int()
        )
        encoder_input = encoder_input.unsqueeze(0)  # (1, seq_len): a batch of one

        # Put the inputs on the same device the decoder is told to use;
        # `.to` is a no-op when they are already there.
        encoder_input = encoder_input.to(device)
        encoder_mask = encoder_mask.to(device)

        if decode_method == 'greedy':
            model_out, attn = greedy_decode(
                model,
                encoder_input,
                encoder_mask,
                src_tokenizer,
                tgt_tokenizer,
                max_len,
                device,
                give_attn=True,
            )
        else:  # 'beam-search'
            model_out = beam_search_decode(
                model,
                encoder_input,
                encoder_mask,
                src_tokenizer,
                tgt_tokenizer,
                max_len,
                device,
            )
            attn = None  # Beam search doesn't give attention scores

        # Tokenizer.decode expects a plain list of integer ids.
        model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().tolist())

    return model_out_text, attn


from config import load_config
from load_model import load_model_tokenizer

if __name__ == '__main__':
    config = load_config(file_name='/config/config_final.yaml')
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)

    text = "ສະບາຍດີ"  # Hello.
    translation, _attn = translate(
        model,
        src_tokenizer,
        tgt_tokenizer,
        text,
        decode_method='beam-search',
    )
    print(translation)