from typing import Tuple

import torch.nn as nn
from torch import Tensor

from layers.decoder_layer import DecoderLayer


class Decoder(nn.Module):
"""
A transformer Decoder (no embeddings or positional embeddings)
Args:
-
Outputs:
- (batch, seq_len, d_model): decoder output
- (batch, seq_len, seq_len): decoder attention
"""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
        num_layers: int,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout_p=dropout_p,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        x: Tensor,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # Run the target sequence through each decoder layer in turn; only
        # the attention weights from the final layer are returned.
        for layer in self.layers:
            x, attn = layer(x, encoder_output, src_mask, tgt_mask)
        return x, attn
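

if __name__ == "__main__":
    # Minimal usage sketch (an addition, not part of the original module):
    # exercises the Decoder with random tensors. The mask shapes below are
    # assumptions; the real shapes depend on how DecoderLayer applies
    # src_mask and tgt_mask internally.
    import torch

    decoder = Decoder(d_model=512, num_heads=8, d_ff=2048, dropout_p=0.1, num_layers=6)

    x = torch.randn(2, 10, 512)       # (batch, tgt_len, d_model) target embeddings
    memory = torch.randn(2, 12, 512)  # (batch, src_len, d_model) encoder output
    src_mask = torch.ones(2, 1, 12, dtype=torch.bool)                  # assumed cross-attention mask
    tgt_mask = torch.tril(torch.ones(10, 10, dtype=torch.bool))[None]  # assumed causal self-attention mask

    out, attn = decoder(x, memory, src_mask, tgt_mask)
    print(out.shape, attn.shape)  # per the docstring: (2, 10, 512) and (2, 10, 10)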