from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear


class ProjectionLayer(nn.Module):
    """Map model features to log-probabilities over a vocabulary.

    The layer applies a single ``Linear(d_model, vocab_size)`` transform and
    then a log-softmax over the last dimension of the result.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """Build the projection.

        Args:
            d_model: Size of the feature dimension fed into the projection.
            vocab_size: Size of the output dimension produced by the projection.
        """
        super().__init__()
        # NOTE(review): ``Linear`` is the project wrapper from
        # ``modules.wrapper``; presumably it behaves like ``nn.Linear`` --
        # confirm against that module.
        self.linear = Linear(d_model, vocab_size)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` and normalise the result with log-softmax.

        Args:
            x: Input features; per the original comment, shaped
                ``(batch, seq_len, d_model)``.

        Returns:
            Log-softmax of the projected input taken over the last dimension;
            per the original comment, shaped ``(batch, seq_len, vocab_size)``.
        """
        logits = self.linear(x)
        return torch.log_softmax(logits, dim=-1)