Upload decoder.py
Browse files- model/decoder.py +49 -0
model/decoder.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch import Tensor
|
5 |
+
|
6 |
+
from layers.decoder_layer import DecoderLayer
|
7 |
+
|
8 |
+
class Decoder(nn.Module):
    """
    A transformer Decoder stack (no token or positional embeddings).

    Stacks ``num_layers`` identical ``DecoderLayer`` modules; each layer
    consumes the previous layer's output together with the encoder output.

    Args:
        d_model (int): model (hidden) dimension of each layer
        num_heads (int): number of attention heads per layer
        d_ff (int): inner dimension of each layer's feed-forward network
        dropout_p (float): dropout probability applied inside each layer
        num_layers (int): number of stacked decoder layers

    Outputs:
        - (batch, seq_len, d_model): decoder output
        - (batch, seq_len, seq_len): attention weights of the last layer
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,  # was annotated `int`; a dropout probability is a float
        num_layers: int,
    ) -> None:
        super().__init__()  # Python 3 zero-argument form
        # num_layers identical layers; ModuleList registers their parameters.
        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout_p=dropout_p,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        x: Tensor,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Run the input through every decoder layer in order.

        Args:
            x: (batch, seq_len, d_model) decoder input
            encoder_output: encoder output fed to every layer
            src_mask: mask passed to each layer — presumably masks encoder
                positions for cross-attention; confirm against DecoderLayer
            tgt_mask: mask passed to each layer — presumably masks future
                positions for self-attention; confirm against DecoderLayer

        Returns:
            Tuple of the last layer's output and its attention weights
            (``None`` attention when the stack is empty).
        """
        # Guard: previously `attn` was unbound (NameError) for num_layers == 0.
        attn = None
        for layer in self.layers:
            # Each layer returns (new hidden states, attention weights);
            # only the last layer's attention survives the loop.
            x, attn = layer(x, encoder_output, src_mask, tgt_mask)
        return x, attn
|