from argparse import Namespace
import copy

import torch.nn as nn


class DualEncoder(nn.Module):
    """
    Dual Encoder which enables siamese models like SBERT and CLIP.

    Two encoders ("stream 0" and "stream 1") are configured from the shared
    ``args`` namespace, each with its own overrides taken from
    ``args.stream_0`` / ``args.stream_1``. When ``args.tie_weights`` is truthy
    both streams use the very same encoder object.
    """

    def __init__(self, args):
        """
        Args:
            args: namespace holding the shared options plus
                ``stream_0`` and ``stream_1`` (mappings of per-stream
                overrides, each providing at least ``encoder``) and
                ``tie_weights``.
        """
        super().__init__()

        self.encoder_0 = self._build_stream_encoder(args, args.stream_0)

        if args.tie_weights:
            # Tied streams share one encoder object, so a separate second
            # encoder is never built just to be thrown away.
            self.encoder_1 = self.encoder_0
        else:
            self.encoder_1 = self._build_stream_encoder(args, args.stream_1)

    @staticmethod
    def _build_stream_encoder(args, stream_overrides):
        """Build one stream's encoder from ``args`` overlaid with ``stream_overrides``."""
        # Imported at call time, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from tencentpretrain.encoders import str2encoder

        # Work on a deep copy so the caller's namespace is left untouched.
        merged = copy.deepcopy(vars(args))
        merged.update(stream_overrides)
        stream_args = Namespace(**merged)
        return str2encoder[stream_args.encoder](stream_args)

    def forward(self, emb, seg):
        """
        Args:
            emb: ([batch_size x seq_length x emb_size], [batch_size x seq_length x emb_size])
            seg: ([batch_size x seq_length], [batch_size x seq_length])
        Returns:
            features_0: [batch_size x seq_length x hidden_size]
            features_1: [batch_size x seq_length x hidden_size]
        """
        features_0 = self.get_encode_0(emb[0], seg[0])
        features_1 = self.get_encode_1(emb[1], seg[1])

        return features_0, features_1

    def get_encode_0(self, emb, seg):
        """Run the stream-0 encoder on ``emb`` and ``seg`` and return its output."""
        return self.encoder_0(emb, seg)

    def get_encode_1(self, emb, seg):
        """Run the stream-1 encoder on ``emb`` and ``seg`` and return its output."""
        return self.encoder_1(emb, seg)