from argparse import Namespace
import torch.nn as nn
import copy
from tencentpretrain.layers.layer_norm import LayerNorm


class DualEmbedding(nn.Module):
    """
    """
    def __init__(self, args, vocab_size):
        super(DualEmbedding, self).__init__()
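        # Imported locally, presumably to avoid a circular import with the
        # embeddings package at module load time.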
        from tencentpretrain.embeddings import str2embedding
        from tencentpretrain.embeddings.embedding import Embedding

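        # Stream 0: overlay its stream-specific settings on the shared config,
        # then assemble the embedding stack named in the merged config.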
        stream_0_args = copy.deepcopy(vars(args))
        stream_0_args.update(args.stream_0)
        stream_0_args = Namespace(**stream_0_args)
        self.embedding_0 = Embedding(stream_0_args)
        for embedding_name in stream_0_args.embedding:
            self.embedding_0.update(str2embedding[embedding_name](stream_0_args, vocab_size), embedding_name)
        self.stream_0_remove_embedding_layernorm = stream_0_args.remove_embedding_layernorm
        if not self.stream_0_remove_embedding_layernorm:
            self.stream_0_layer_norm = LayerNorm(stream_0_args.emb_size)

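        # Stream 1: same construction with its own overrides.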
        stream_1_args = copy.deepcopy(vars(args))
        stream_1_args.update(args.stream_1)
        stream_1_args = Namespace(**stream_1_args)
        self.embedding_1 = Embedding(stream_1_args)
        for embedding_name in stream_1_args.embedding:
            self.embedding_1.update(str2embedding[embedding_name](stream_1_args, vocab_size), embedding_name)
        self.stream_1_remove_embedding_layernorm = stream_1_args.remove_embedding_layernorm
        if not self.stream_1_remove_embedding_layernorm:
            self.stream_1_layer_norm = LayerNorm(stream_1_args.emb_size)
        self.dropout = nn.Dropout(args.dropout)

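        # Weight tying: both streams reuse the stream-1 embedding module and
        # thus share parameters; the two LayerNorms remain separate.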
        if args.tie_weights:
            self.embedding_0 = self.embedding_1

    def forward(self, src, seg):
        """
        Args:
            src: ([batch_size x seq_length], [batch_size x seq_length])
            seg: ([batch_size x seq_length], [batch_size x seq_length])
        Returns:
            emb_0: [batch_size x seq_length x hidden_size]
            emb_1: [batch_size x seq_length x hidden_size]
        """
        emb_0 = self.get_embedding_0(src[0], seg[0])
        emb_1 = self.get_embedding_1(src[1], seg[1])

        emb_0 = self.dropout(emb_0)
        emb_1 = self.dropout(emb_1)

        return emb_0, emb_1

    def get_embedding_0(self, src, seg):
        emb = self.embedding_0(src, seg)
        if not self.stream_0_remove_embedding_layernorm:
            emb = self.stream_0_layer_norm(emb)
        return emb

    def get_embedding_1(self, src, seg):
        emb = self.embedding_1(src, seg)
        if not self.stream_1_remove_embedding_layernorm:
            emb = self.stream_1_layer_norm(emb)
        return emb
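

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The config fields below are assumptions
# based on typical TencentPretrain setups, not guaranteed by this module:
#
#     from argparse import Namespace
#     import torch
#
#     base = dict(emb_size=768, dropout=0.1, tie_weights=False,
#                 remove_embedding_layernorm=False, embedding=["word"],
#                 stream_0={}, stream_1={})
#     args = Namespace(**base)
#     dual_emb = DualEmbedding(args, vocab_size=21128)
#
#     src = (torch.zeros(2, 8, dtype=torch.long),) * 2
#     seg = (torch.ones(2, 8, dtype=torch.long),) * 2
#     emb_0, emb_1 = dual_emb(src, seg)  # each: [2, 8, 768]
# ---------------------------------------------------------------------------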