moiduy04 committed on
Commit
27f7f75
1 Parent(s): 7ac9d34

Upload 4 files

Browse files
Files changed (4) hide show
  1. decode_method.py +50 -0
  2. requirements.txt +8 -0
  3. tokenizer_lo.json +0 -0
  4. tokenizer_vi.json +0 -0
decode_method.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ from transformer import Transformer
5
+ from tokenizers import Tokenizer
6
+ from dataset import causal_mask
7
+
8
+
9
def greedy_decode(
    model: Transformer,
    src: Tensor,
    src_mask: Tensor,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    tgt_max_seq_len: int,
    device,
    give_attn: bool = False,
):
    """
    Greedily decode a single source sequence into target token ids.

    The source is encoded once with ``model.encode``. The target sequence
    starts as a single ``<sos>`` token and grows by one token per step: the
    decoder is run on the tokens produced so far, the last position is passed
    through ``model.project``, and the highest-scoring id is appended.
    Decoding stops once ``<eos>`` has been appended or the sequence holds
    ``tgt_max_seq_len`` tokens.

    Only one sequence is decoded per call (the decoder input is created with
    a leading dimension of 1).

    Args:
        model: Object exposing ``encode(src, src_mask)``,
            ``decode(encoder_output, src_mask, decoder_input, decoder_mask)``
            returning ``(decoder_output, attn)``, and ``project(x)``.
        src: Source token ids. Its dtype is reused for the decoder input and
            the causal mask. NOTE(review): presumably shaped (1, src_len) --
            confirm against the caller.
        src_mask: Mask forwarded unchanged to ``encode`` and ``decode``.
        src_tokenizer: Source tokenizer; consulted only as a fallback when
            the target tokenizer does not define a special token.
        tgt_tokenizer: Target tokenizer; supplies the ``<sos>``/``<eos>`` ids
            because the decoder consumes and emits target-vocabulary ids.
        tgt_max_seq_len: Maximum number of tokens in the returned sequence,
            counting the leading ``<sos>``.
        device: Device on which the decoder input and mask are placed.
        give_attn: When true, also return the attention object from the last
            ``model.decode`` call.

    Returns:
        A 1-D tensor of token ids starting with ``<sos>``. When ``give_attn``
        is true, a tuple ``(token_ids, attn)`` where ``attn`` is ``None`` if
        no decoding step ran.

    Raises:
        ValueError: If neither tokenizer defines ``<sos>`` or ``<eos>``.

    Note:
        Gradient tracking is not disabled here; wrap the call in
        ``torch.no_grad()`` when decoding for inference.
    """

    def _special_token_id(token: str) -> int:
        # Prefer the target vocabulary: decoder ids live in that space. The
        # source tokenizer is kept as a fallback so setups where only it
        # defines the special tokens keep working.
        idx = tgt_tokenizer.token_to_id(token)
        if idx is None:
            idx = src_tokenizer.token_to_id(token)
        if idx is None:
            raise ValueError(f"special token {token!r} not found in either tokenizer")
        return idx

    sos_idx = _special_token_id('<sos>')
    eos_idx = _special_token_id('<eos>')

    encoder_output = model.encode(src, src_mask)

    attn = None
    decoder_input = torch.full((1, 1), sos_idx, dtype=src.dtype, device=device)

    # A `<` bound (rather than `== tgt_max_seq_len`) also terminates when the
    # limit is smaller than the initial length of 1.
    while decoder_input.size(1) < tgt_max_seq_len:
        # build target mask
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src).to(device)

        # get decoder output
        decoder_output, attn = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)

        # score the vocabulary at the last position and pick the best id
        prob = model.project(decoder_output[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_id = next_word.item()

        decoder_input = torch.cat(
            [decoder_input, torch.full((1, 1), next_id, dtype=src.dtype, device=device)],
            dim=1,
        )

        if next_id == eos_idx:
            break

    if give_attn:
        return (decoder_input.squeeze(0), attn)
    return decoder_input.squeeze(0)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ torch
3
+ torchvision
4
+ torchaudio
5
+ torchmetrics
6
+ tokenizers
7
+ transformers
8
+ tqdm
tokenizer_lo.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_vi.json ADDED
The diff for this file is too large to render. See raw diff