File size: 654 Bytes
7900c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn

from tencentpretrain.embeddings.word_embedding import WordEmbedding
from tencentpretrain.embeddings.patch_embedding import PatchEmbedding


class WordPatchEmbedding(nn.Module):
    """
    Combined embedding built from a WordEmbedding and a PatchEmbedding.

    The two sub-embeddings are applied to separate inputs and their outputs
    are joined along dimension 1 (presumably text tokens followed by image
    patches on the sequence axis -- confirm against the sub-embeddings).
    """

    def __init__(self, args, vocab_size):
        """
        Args:
            args: configuration object forwarded unchanged to both
                sub-embeddings.
            vocab_size: forwarded to WordEmbedding only.
        """
        super().__init__()
        # The language branch is created before the vision branch.
        self.language_embedding = WordEmbedding(args, vocab_size)
        # PatchEmbedding receives None as its second constructor argument.
        self.vision_embedding = PatchEmbedding(args, None)

    def forward(self, src, _):
        """
        Embed both parts of ``src`` and concatenate the results.

        Args:
            src: indexable container; ``src[0]`` is passed to the language
                embedding and ``src[1]`` to the vision embedding.
            _: unused; each sub-embedding is called with None as its
                second argument instead.

        Returns:
            The language embedding followed by the vision embedding,
            concatenated along dim 1.
        """
        pieces = (
            self.language_embedding(src[0], None),
            self.vision_embedding(src[1], None),
        )
        return torch.cat(pieces, dim=1)