homemade_lo_vi / modules /transformer_embedding.py
moiduy04's picture
Upload 18 files
bc1ada8
raw
history blame
672 Bytes
import math

import numpy as np
import torch.nn as nn
from torch import Tensor
class TransformerEmbedding(nn.Module):
    """
    Input Embeddings (section 3.4 of "Attention Is All You Need").

    Maps token indices to dense vectors of size d_model and scales the
    lookup result by sqrt(d_model), as prescribed by the paper.

    Args:
        d_model (int): dimension of the model (embedding vector size).
        num_embeddings (int): size of the vocabulary/dictionary.
    """
    def __init__(self, d_model: int, num_embeddings: int) -> None:
        super().__init__()
        # Precompute the scale factor once; math.sqrt on a scalar gives a
        # plain Python float (np.sqrt returned an unnecessary numpy scalar).
        self.sqrt_d_model = math.sqrt(d_model)
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=d_model)
    def forward(self, x: Tensor) -> Tensor:
        """Embed token indices and scale by sqrt(d_model).

        Args:
            x: LongTensor of token indices, shape (..., seq_len).

        Returns:
            Tensor of shape (*x.shape, d_model).
        """
        return self.embedding(x) * self.sqrt_d_model