|
from typing import Tuple |
|
|
|
import torch.nn as nn |
|
from torch import Tensor |
|
|
|
from layers.decoder_layer import DecoderLayer |
|
|
|
class Decoder(nn.Module):
    """
    A transformer decoder stack (no token or positional embeddings).

    Stacks ``num_layers`` identical ``DecoderLayer`` modules; each call to
    ``forward`` pipes the input through every layer in order and keeps the
    attention weights returned by the last layer.

    Args:
        d_model: dimensionality of the model's hidden representations
        num_heads: number of attention heads per decoder layer
        d_ff: hidden size of each layer's position-wise feed-forward network
        dropout_p: dropout probability passed to each layer
        num_layers: number of stacked decoder layers

    Inputs (forward):
        x (batch, seq_len, d_model): target-side input representations
        encoder_output (batch, src_len, d_model): encoder output attended
            over by each layer (cross-attention)
        src_mask: mask forwarded to each layer for the encoder-decoder
            attention
        tgt_mask: mask forwarded to each layer for the masked self-attention

    Outputs:
        - (batch, seq_len, d_model): decoder output
        - (batch, seq_len, seq_len): attention weights from the last decoder
          layer (``None`` if ``num_layers == 0``)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,  # fixed: dropout is a probability, was annotated `int`
        num_layers: int,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout_p=dropout_p,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        x: Tensor,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # Initialize so the return is well-defined even for an empty stack;
        # without this, `attn` is unbound when num_layers == 0
        # (UnboundLocalError at the return statement).
        attn = None
        for layer in self.layers:
            x, attn = layer(x, encoder_output, src_mask, tgt_mask)
        return x, attn