import torch.nn as nn
from torch import Tensor
import numpy as np


class TransformerEmbedding(nn.Module):
    """
    Input Embeddings (section 3.4)

    Looks up a learned vector of size ``d_model`` for every index in the
    input and multiplies the result by ``sqrt(d_model)``.

    Args:
        - d_model (int): dimension of model (size of each embedding vector)
        - num_embeddings (int): size of the dictionary
    """

    def __init__(self, d_model: int, num_embeddings: int) -> None:
        super().__init__()
        vocab_size, vector_dim = num_embeddings, d_model
        # Lookup table with one row of length d_model per dictionary entry.
        self.embedding = nn.Embedding(vocab_size, vector_dim)
        # Constant scale factor, computed once and reused on every forward pass.
        self.sqrt_d_model = np.sqrt(vector_dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Embed the indices in ``x`` and scale them by ``sqrt(d_model)``.

        Args:
            - x (Tensor): integer indices into the dictionary
              (presumably shaped (batch, seq_len) -- confirm against callers)

        Returns:
            Tensor: the embedded input with one extra trailing dimension of
            size ``d_model``, multiplied by ``sqrt(d_model)``.
        """
        looked_up = self.embedding(x)
        return looked_up * self.sqrt_d_model