"""
PyTorch modules.

Some classes are modified from HuggingFace Transformers
(https://github.com/huggingface/transformers).
"""

import logging

import torch
import torch.nn.functional as F
from torch import nn

logger = logging.getLogger(__name__)

try:
    # Prefer apex's fused LayerNorm kernel when apex is installed.
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError):
    BertLayerNorm = torch.nn.LayerNorm

from model.transformer.bert import BertEncoder
from model.layers import NetVLAD, LinearLayer
from model.transformer.bert_embed import BertEmbeddings
from utils.model_utils import mask_logits


class TransformerBaseModel(nn.Module):
    """
    Base Transformer model
    """

    def __init__(self, config):
        super(TransformerBaseModel, self).__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

    def forward(self, features, position_ids, token_type_ids, attention_mask):
        embedding_output = self.embeddings(token_type_ids=token_type_ids,
                                           inputs_embeds=features,
                                           position_ids=position_ids)
        encoder_outputs = self.encoder(embedding_output, attention_mask)
        sequence_output = encoder_outputs[0]
        return sequence_output

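# Illustrative I/O for TransformerBaseModel (a sketch; the exact conventions
# come from BertEmbeddings/BertEncoder in this repo):
#   features:        (batch, seq_len, hidden)  pre-projected input features
#   position_ids:    (batch, seq_len)          integer positions
#   token_type_ids:  (batch, seq_len)          segment / modality ids
#   attention_mask:  (batch, seq_len)          typically 1 = keep, 0 = padding
#   returns:         (batch, seq_len, hidden)  contextualized sequence output

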
class TwoModalEncoder(nn.Module):
    """
    Two-modality Transformer encoder model
    """

    def __init__(self, config, img_dim, text_dim, hidden_dim, split_num, output_split=True):
        super(TwoModalEncoder, self).__init__()
        self.img_linear = LinearLayer(in_hsz=img_dim, out_hsz=hidden_dim)
        self.text_linear = LinearLayer(in_hsz=text_dim, out_hsz=hidden_dim)
        self.transformer = TransformerBaseModel(config)
        self.output_split = output_split
        if self.output_split:
            self.split_num = split_num

    def forward(self, visual_features, visual_position_ids, visual_token_type_ids, visual_attention_mask,
                text_features, text_position_ids, text_token_type_ids, text_attention_mask):
        # Project both modalities to the shared hidden size, then concatenate
        # them along the sequence dimension and encode them jointly.
        transformed_im = self.img_linear(visual_features)
        transformed_text = self.text_linear(text_features)

        transformer_input_feat = torch.cat((transformed_im, transformed_text), dim=1)
        transformer_input_feat_pos_id = torch.cat((visual_position_ids, text_position_ids), dim=1)
        transformer_input_feat_token_id = torch.cat((visual_token_type_ids, text_token_type_ids), dim=1)
        transformer_input_feat_mask = torch.cat((visual_attention_mask, text_attention_mask), dim=1)

        output = self.transformer(features=transformer_input_feat,
                                  position_ids=transformer_input_feat_pos_id,
                                  token_type_ids=transformer_input_feat_token_id,
                                  attention_mask=transformer_input_feat_mask)

        if self.output_split:
            # Split the joint sequence back into per-modality segments of length split_num.
            return torch.split(output, self.split_num, dim=1)
        return output

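# Shape sketch for TwoModalEncoder (illustrative; actual sequence lengths come
# from the data pipeline): visual_features (B, n_vis, img_dim) and
# text_features (B, n_txt, text_dim) are projected and concatenated into a
# joint sequence of length n_vis + n_txt; with output_split=True and
# split_num == n_vis == n_txt, the output is split back into two
# (B, split_num, hidden_dim) tensors.

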
class OneModalEncoder(nn.Module):
    """
    One-modality Transformer encoder model
    """

    def __init__(self, config, input_dim, hidden_dim):
        super(OneModalEncoder, self).__init__()
        self.linear = LinearLayer(in_hsz=input_dim, out_hsz=hidden_dim)
        self.transformer = TransformerBaseModel(config)

    def forward(self, features, position_ids, token_type_ids, attention_mask):
        transformed_features = self.linear(features)
        output = self.transformer(features=transformed_features,
                                  position_ids=position_ids,
                                  token_type_ids=token_type_ids,
                                  attention_mask=attention_mask)
        return output

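# Usage sketch for OneModalEncoder (illustrative; the dimensions below are
# assumptions, not fixed by this module):
#   query_encoder = OneModalEncoder(config=query_bert_config, input_dim=768, hidden_dim=768)
#   out = query_encoder(features, position_ids, token_type_ids, attention_mask)
#   # out: (batch, seq_len, hidden_dim)

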
class VideoQueryEncoder(nn.Module):
    """
    Encodes videos (visual features, optionally with subtitles) and queries
    with Transformer encoders.
    """

    def __init__(self, config, video_modality,
                 visual_dim=4352, text_dim=768,
                 query_dim=768, hidden_dim=768, split_num=100):
        super(VideoQueryEncoder, self).__init__()
        self.use_sub = len(video_modality) > 1
        if self.use_sub:
            self.videoEncoder = TwoModalEncoder(config=config.bert_config,
                                                img_dim=visual_dim,
                                                text_dim=text_dim,
                                                hidden_dim=hidden_dim,
                                                split_num=split_num)
        else:
            self.videoEncoder = OneModalEncoder(config=config.bert_config,
                                                input_dim=visual_dim,
                                                hidden_dim=hidden_dim)

        self.queryEncoder = OneModalEncoder(config=config.query_bert_config,
                                            input_dim=query_dim,
                                            hidden_dim=hidden_dim)

    def forward_repr_query(self, batch):
        query_output = self.queryEncoder(
            features=batch["query"]["feat"],
            position_ids=batch["query"]["feat_pos_id"],
            token_type_ids=batch["query"]["feat_token_id"],
            attention_mask=batch["query"]["feat_mask"])
        return query_output

    def forward_repr_video(self, batch):
        video_output = dict()

        # When each sample carries several candidate videos, the features are
        # (bsz, num_video, seq_len, ...); flatten the first two dimensions so
        # the encoders see a plain (bsz * num_video, seq_len, ...) batch.
        if len(batch["visual"]["feat"].size()) == 4:
            bsz, num_video = batch["visual"]["feat"].size()[:2]
            for key in batch.keys():
                if key in ["visual", "sub"]:
                    for key_2 in batch[key]:
                        if key_2 in ["feat", "feat_mask", "feat_pos_id", "feat_token_id"]:
                            trailing_shape = batch[key][key_2].size()[2:]
                            batch[key][key_2] = batch[key][key_2].view((bsz * num_video,) + trailing_shape)

        if self.use_sub:
            video_output["visual"], video_output["sub"] = self.videoEncoder(
                visual_features=batch["visual"]["feat"],
                visual_position_ids=batch["visual"]["feat_pos_id"],
                visual_token_type_ids=batch["visual"]["feat_token_id"],
                visual_attention_mask=batch["visual"]["feat_mask"],
                text_features=batch["sub"]["feat"],
                text_position_ids=batch["sub"]["feat_pos_id"],
                text_token_type_ids=batch["sub"]["feat_token_id"],
                text_attention_mask=batch["sub"]["feat_mask"])
        else:
            video_output["visual"] = self.videoEncoder(
                features=batch["visual"]["feat"],
                position_ids=batch["visual"]["feat_pos_id"],
                token_type_ids=batch["visual"]["feat_token_id"],
                attention_mask=batch["visual"]["feat_mask"])

        return video_output

    def forward_repr_both(self, batch):
        video_output = self.forward_repr_video(batch)
        query_output = self.forward_repr_query(batch)
        return {"video_feat": video_output,
                "query_feat": query_output}

    def forward(self, batch, task="repr_both"):
        if task == "repr_both":
            return self.forward_repr_both(batch)
        elif task == "repr_video":
            return self.forward_repr_video(batch)
        elif task == "repr_query":
            return self.forward_repr_query(batch)
        raise ValueError("Unknown task: {}".format(task))

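# Expected batch layout for VideoQueryEncoder (a sketch inferred from the code
# above; the exact tensors are produced by the data pipeline):
#   batch["visual"], batch["sub"], batch["query"] are dicts with the keys
#   "feat", "feat_pos_id", "feat_token_id", "feat_mask".
#   Video-side tensors may carry an extra num_video dimension, i.e.
#   (bsz, num_video, seq_len, ...), which forward_repr_video flattens to
#   (bsz * num_video, seq_len, ...) before encoding.

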
class QueryWeightEncoder(nn.Module):
    """
    Query Weight Encoder

    Uses NetVLAD to aggregate contextual query features, then an FC layer +
    softmax to produce fusion weights for each video modality.
    """

    def __init__(self, config, video_modality):
        super(QueryWeightEncoder, self).__init__()
        self.text_pooling = NetVLAD(feature_size=config.hidden_size,
                                    cluster_size=config.text_cluster)
        self.moe_txt_dropout = nn.Dropout(config.moe_dropout_prob)

        self.moe_fc_txt = nn.Linear(in_features=self.text_pooling.out_dim,
                                    out_features=len(video_modality),
                                    bias=False)
        self.video_modality = video_modality

    def forward(self, query_feat):
        # Aggregate the contextual query features into a single vector.
        pooled_text = self.text_pooling(query_feat)
        pooled_text = self.moe_txt_dropout(pooled_text)

        # One fusion weight per modality, normalized with a softmax.
        moe_weights = self.moe_fc_txt(pooled_text)
        softmax_moe_weights = F.softmax(moe_weights, dim=1)

        moe_weights_dict = dict()
        for modality, moe_weight in zip(self.video_modality,
                                        torch.split(softmax_moe_weights, 1, dim=1)):
            moe_weights_dict[modality] = moe_weight.squeeze(1)

        return moe_weights_dict

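# Output sketch for QueryWeightEncoder (illustrative): with
# video_modality = ["visual", "sub"], the returned dict maps each modality name
# to a (batch,) tensor of weights that sum to 1 across modalities per query.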