szukevin's picture
upload
7900c16
raw
history blame
2.66 kB
from argparse import Namespace
import torch.nn as nn
import copy
from tencentpretrain.layers.layer_norm import LayerNorm
class DualEmbedding(nn.Module):
"""
"""
def __init__(self, args, vocab_size):
super(DualEmbedding, self).__init__()
from tencentpretrain.embeddings import str2embedding
from tencentpretrain.embeddings.embedding import Embedding
stream_0_args = copy.deepcopy(vars(args))
stream_0_args.update(args.stream_0)
stream_0_args = Namespace(**stream_0_args)
self.embedding_0 = Embedding(stream_0_args)
for embedding_name in stream_0_args.embedding:
self.embedding_0.update(str2embedding[embedding_name](stream_0_args, vocab_size), embedding_name)
self.stream_0_remove_embedding_layernorm = stream_0_args.remove_embedding_layernorm
if not self.stream_0_remove_embedding_layernorm:
self.stream_0_layer_norm = LayerNorm(stream_0_args.emb_size)
stream_1_args = copy.deepcopy(vars(args))
stream_1_args.update(args.stream_1)
stream_1_args = Namespace(**stream_1_args)
self.embedding_1 = Embedding(stream_1_args)
for embedding_name in stream_1_args.embedding:
self.embedding_1.update(str2embedding[embedding_name](stream_1_args, vocab_size), embedding_name)
self.stream_1_remove_embedding_layernorm = stream_1_args.remove_embedding_layernorm
if not self.stream_1_remove_embedding_layernorm:
self.stream_1_layer_norm = LayerNorm(stream_1_args.emb_size)
self.dropout = nn.Dropout(args.dropout)
if args.tie_weights:
self.embedding_0 = self.embedding_1
def forward(self, src, seg):
"""
Args:
src: ([batch_size x seq_length], [batch_size x seq_length])
seg: ([batch_size x seq_length], [batch_size x seq_length])
Returns:
emb_0: [batch_size x seq_length x hidden_size]
emb_1: [batch_size x seq_length x hidden_size]
"""
emb_0 = self.get_embedding_0(src[0], seg[0])
emb_1 = self.get_embedding_1(src[1], seg[1])
emb_0 = self.dropout(emb_0)
emb_1 = self.dropout(emb_1)
return emb_0, emb_1
def get_embedding_0(self, src, seg):
emb = self.embedding_0(src, seg)
if not self.stream_0_remove_embedding_layernorm:
emb = self.stream_0_layer_norm(emb)
return emb
def get_embedding_1(self, src, seg):
emb = self.embedding_1(src, seg)
if not self.stream_1_remove_embedding_layernorm:
emb = self.stream_1_layer_norm(emb)
return emb