|
import math

import numpy as np
import torch.nn as nn
from torch import Tensor
|
|
|
|
|
class TransformerEmbedding(nn.Module):
    """
    Input Embeddings (section 3.4)

    Embeds token indices into vectors of size d_model and multiplies the
    result by sqrt(d_model), as in section 3.4 ("Embeddings and Softmax") of
    "Attention Is All You Need".

    Args:
        - d_model (int): dimension of the model, i.e. size of each embedding vector
        - num_embeddings (int): size of the dictionary (number of distinct token ids)
    """

    def __init__(self, d_model: int, num_embeddings: int) -> None:
        super().__init__()
        self.d_model = d_model
        # Keep the scale as a plain Python float (rather than a NumPy
        # np.float64 scalar) so it is an ordinary scalar in torch arithmetic.
        self.sqrt_d_model = math.sqrt(d_model)
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=d_model)

    def forward(self, x: Tensor) -> Tensor:
        """Look up the embeddings for ``x`` and scale them by sqrt(d_model).

        Args:
            x: integer tensor of token indices, each in ``[0, num_embeddings)``.

        Returns:
            Tensor of shape ``(*x.shape, d_model)``.
        """
        return self.embedding(x) * self.sqrt_d_model
|
|