import math

import torch.nn as nn


class WordEmbedding(nn.Module):
    """
    Token embedding lookup. When sinusoidal positional encoding is enabled,
    the embedding output is scaled by sqrt(emb_size), following the
    Transformer convention (Vaswani et al., 2017).
    """
    def __init__(self, args, vocab_size):
        super(WordEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, args.emb_size)
        self.emb_size = args.emb_size
        # Scale embeddings only when sinusoidal positional encoding is used.
        self.sinusoidalpos = False
        if "sinusoidalpos" in args.embedding:
            self.sinusoidalpos = True

    def forward(self, src, _):
        """
        Args:
            src: [batch_size x seq_length]
            seg: [batch_size x seq_length] (unused by this module)
        Returns:
            emb: [batch_size x seq_length x hidden_size]
        """
        emb = self.embedding(src)
        if self.sinusoidalpos:
            # Scale by sqrt(emb_size) so token embeddings and the sinusoidal
            # position signals added downstream have comparable magnitudes.
            return emb * math.sqrt(self.emb_size)
        else:
            return emb
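

if __name__ == "__main__":
    # Minimal usage sketch, assuming the real ``args`` object exposes
    # ``emb_size`` (an int) and ``embedding`` (a container of embedding-type
    # names checked with ``in``). Both attribute names come from the class
    # above; the concrete values here are illustrative only.
    from argparse import Namespace

    import torch

    args = Namespace(emb_size=512, embedding=["word", "sinusoidalpos"])
    model = WordEmbedding(args, vocab_size=1000)

    src = torch.randint(0, 1000, (2, 8))  # [batch_size x seq_length]
    seg = torch.zeros_like(src)           # placeholder; unused by forward
    emb = model(src, seg)
    print(emb.shape)                      # torch.Size([2, 8, 512])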