szukevin's picture
upload
7900c16
raw
history blame
1.16 kB
import torch.nn as nn
import torch
from tencentpretrain.layers.layer_norm import LayerNorm
class Embedding(nn.Module):
def __init__(self, args):
super(Embedding, self).__init__()
self.embedding_name_list = []
self.dropout = nn.Dropout(args.dropout)
self.remove_embedding_layernorm = args.remove_embedding_layernorm
if not self.remove_embedding_layernorm and "dual" not in args.embedding:
self.layer_norm = LayerNorm(args.emb_size)
def update(self, embedding, embedding_name):
setattr(self, embedding_name, embedding)
self.embedding_name_list.append(embedding_name)
def forward(self, src, seg):
if self.embedding_name_list[0] == "dual":
return self.dual(src, seg)
for i, embedding_name in enumerate(self.embedding_name_list):
embedding = getattr(self, embedding_name)
if i == 0:
emb = embedding(src, seg)
else:
emb += embedding(src, seg)
if not self.remove_embedding_layernorm:
emb = self.layer_norm(emb)
emb = self.dropout(emb)
return emb