moiduy04 commited on
Commit
befbc32
1 Parent(s): e5ce7af

Upload decoder.py

Browse files
Files changed (1) hide show
  1. model/decoder.py +49 -0
model/decoder.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+
6
+ from layers.decoder_layer import DecoderLayer
7
+
8
+ class Decoder(nn.Module):
9
+ """
10
+ A transformer Decoder (no embeddings or positional embeddings)
11
+
12
+ Args:
13
+ -
14
+
15
+ Outputs:
16
+ - (batch, seq_len, d_model): decoder output
17
+ - (batch, seq_len, seq_len): decoder attention
18
+ """
19
+ def __init__(
20
+ self,
21
+ d_model: int,
22
+ num_heads: int,
23
+ d_ff: int,
24
+ dropout_p: int,
25
+ num_layers: int,
26
+ ) -> None:
27
+ super(Decoder, self).__init__()
28
+ self.layers = nn.ModuleList(
29
+ [
30
+ DecoderLayer(
31
+ d_model=d_model,
32
+ num_heads=num_heads,
33
+ d_ff=d_ff,
34
+ dropout_p=dropout_p,
35
+ )
36
+ for _ in range(num_layers)
37
+ ]
38
+ )
39
+
40
+ def forward(
41
+ self,
42
+ x: Tensor,
43
+ encoder_output: Tensor,
44
+ src_mask: Tensor,
45
+ tgt_mask: Tensor
46
+ ) -> Tuple[Tensor, Tensor]:
47
+ for layer in self.layers:
48
+ x, attn = layer(x, encoder_output, src_mask, tgt_mask)
49
+ return x, attn