"""
Pytorch modules
some classes are modified from HuggingFace
(https://github.com/huggingface/transformers)
"""
import logging

import torch
import torch.nn.functional as F
from torch import nn

logger = logging.getLogger(__name__)

try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError):
    BertLayerNorm = torch.nn.LayerNorm

from model.transformer.bert import BertEncoder
from model.transformer.bert_embed import BertEmbeddings
from model.layers import NetVLAD, LinearLayer
from utils.model_utils import mask_logits
class TransformerBaseModel(nn.Module):
"""
Base Transformer model
"""
def __init__(self, config):
super(TransformerBaseModel, self).__init__()
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
    def forward(self, features, position_ids, token_type_ids, attention_mask):
# embedding layer
embedding_output = self.embeddings(token_type_ids=token_type_ids,
inputs_embeds=features,
position_ids=position_ids)
encoder_outputs = self.encoder(embedding_output, attention_mask)
sequence_output = encoder_outputs[0]
return sequence_output
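
# Shape sketch for TransformerBaseModel.forward (assumed from standard BERT
# conventions, since BertEmbeddings/BertEncoder are defined elsewhere):
#   features:       (batch, seq_len, hidden_size) pre-projected input embeddings
#   position_ids:   (batch, seq_len)
#   token_type_ids: (batch, seq_len)
#   attention_mask: (batch, seq_len), 1 = real token, 0 = padding
# and the returned sequence_output is (batch, seq_len, hidden_size).
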
class TwoModalEncoder(nn.Module):
    """
    Two-modality Transformer encoder: projects visual and text features to a
    shared hidden size, concatenates them along the sequence dimension, and
    encodes them jointly.
    """
    def __init__(self, config, img_dim, text_dim, hidden_dim, split_num, output_split=True):
super(TwoModalEncoder, self).__init__()
self.img_linear = LinearLayer(
in_hsz=img_dim, out_hsz=hidden_dim)
self.text_linear = LinearLayer(
in_hsz=text_dim, out_hsz=hidden_dim)
self.transformer = TransformerBaseModel(config)
self.output_split = output_split
if self.output_split:
self.split_num = split_num
    def forward(self, visual_features, visual_position_ids, visual_token_type_ids, visual_attention_mask,
                text_features, text_position_ids, text_token_type_ids, text_attention_mask):
transformed_im = self.img_linear(visual_features)
transformed_text = self.text_linear(text_features)
        # concatenate the two modalities along the sequence dimension
        transformer_input_feat = torch.cat((transformed_im, transformed_text), dim=1)
        transformer_input_feat_pos_id = torch.cat((visual_position_ids, text_position_ids), dim=1)
        transformer_input_feat_token_id = torch.cat((visual_token_type_ids, text_token_type_ids), dim=1)
        transformer_input_feat_mask = torch.cat((visual_attention_mask, text_attention_mask), dim=1)
output = self.transformer(features=transformer_input_feat,
position_ids=transformer_input_feat_pos_id,
token_type_ids=transformer_input_feat_token_id,
attention_mask=transformer_input_feat_mask)
        if self.output_split:
            # split the joint output back into (visual, text) segments;
            # torch.split returns chunks of size `split_num` along dim 1
            return torch.split(output, self.split_num, dim=1)
else:
return output
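
# Sequence layout sketch for TwoModalEncoder (illustrative lengths assumed):
# with split_num == 100 visual tokens and, say, 30 subtitle tokens, the joint
# sequence has length 130 and torch.split(output, 100, dim=1) returns the pair
# (visual part [:, :100], text part [:, 100:130]). Note this unpacks cleanly
# only while the text length does not exceed split_num.
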
class OneModalEncoder(nn.Module):
    """
    Single-modality Transformer encoder: a linear projection followed by the
    base Transformer.
    """
    def __init__(self, config, input_dim, hidden_dim):
super(OneModalEncoder, self).__init__()
self.linear = LinearLayer(
in_hsz=input_dim, out_hsz=hidden_dim)
self.transformer = TransformerBaseModel(config)
def forward(self, features, position_ids, token_type_ids, attention_mask):
transformed_features = self.linear(features)
output = self.transformer(features=transformed_features,
position_ids=position_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask)
return output
class VideoQueryEncoder(nn.Module):
    def __init__(self, config, video_modality,
                 visual_dim=4352, text_dim=768,
                 query_dim=768, hidden_dim=768, split_num=100):
super(VideoQueryEncoder, self).__init__()
self.use_sub = len(video_modality) > 1
        if self.use_sub:
            self.videoEncoder = TwoModalEncoder(config=config.bert_config,
                                                img_dim=visual_dim,
                                                text_dim=text_dim,
                                                hidden_dim=hidden_dim,
                                                split_num=split_num)
        else:
            self.videoEncoder = OneModalEncoder(config=config.bert_config,
                                                input_dim=visual_dim,
                                                hidden_dim=hidden_dim)
        self.queryEncoder = OneModalEncoder(config=config.query_bert_config,
                                            input_dim=query_dim,
                                            hidden_dim=hidden_dim)
def forward_repr_query(self, batch):
query_output = self.queryEncoder(
features=batch["query"]["feat"],
position_ids=batch["query"]["feat_pos_id"],
token_type_ids=batch["query"]["feat_token_id"],
attention_mask=batch["query"]["feat_mask"]
)
return query_output
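
    # Expected batch layout (inferred from the forward_* methods here):
    # batch["query"], batch["visual"] and batch["sub"] each hold "feat",
    # "feat_pos_id", "feat_token_id" and "feat_mask" tensors; the video-side
    # entries may carry an extra num_video dimension, flattened below.
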
    def forward_repr_video(self, batch):
        video_output = dict()
        if len(batch["visual"]["feat"].size()) == 4:
            # inputs arrive as (bsz, num_video, seq_len, dim); flatten the
            # per-sample video dimension so the encoder sees a plain batch
            bsz, num_video = batch["visual"]["feat"].size()[:2]
            for key in batch.keys():
                if key in ["visual", "sub"]:
                    for key_2 in batch[key]:
                        if key_2 in ["feat", "feat_mask", "feat_pos_id", "feat_token_id"]:
                            shape_list = batch[key][key_2].size()[2:]
                            batch[key][key_2] = batch[key][key_2].view((bsz * num_video,) + shape_list)
if self.use_sub:
video_output["visual"], video_output["sub"] = self.videoEncoder(
visual_features=batch["visual"]["feat"],
visual_position_ids=batch["visual"]["feat_pos_id"],
visual_token_type_ids=batch["visual"]["feat_token_id"],
visual_attention_mask=batch["visual"]["feat_mask"],
text_features=batch["sub"]["feat"],
text_position_ids=batch["sub"]["feat_pos_id"],
text_token_type_ids=batch["sub"]["feat_token_id"],
text_attention_mask=batch["sub"]["feat_mask"]
)
else:
video_output["visual"] = self.videoEncoder(
features=batch["visual"]["feat"],
position_ids=batch["visual"]["feat_pos_id"],
token_type_ids=batch["visual"]["feat_token_id"],
attention_mask=batch["visual"]["feat_mask"]
)
return video_output
def forward_repr_both(self, batch):
video_output = self.forward_repr_video(batch)
query_output = self.forward_repr_query(batch)
return {"video_feat": video_output,
"query_feat": query_output}
    def forward(self, batch, task="repr_both"):
        if task == "repr_both":
            return self.forward_repr_both(batch)
        elif task == "repr_video":
            return self.forward_repr_video(batch)
        elif task == "repr_query":
            return self.forward_repr_query(batch)
        else:
            raise ValueError(f"Unknown task: {task}")
class QueryWeightEncoder(nn.Module):
"""
Query Weight Encoder
Using NetVLAD to aggreate contextual query features
Using FC + Softmax to get fusion weights for each modality
"""
def __init__(self, config, video_modality):
super(QueryWeightEncoder, self).__init__()
        # NetVLAD pooling over contextual query features
        self.text_pooling = NetVLAD(feature_size=config.hidden_size,
                                    cluster_size=config.text_cluster)
        self.moe_txt_dropout = nn.Dropout(config.moe_dropout_prob)
        # FC mapping the pooled query to one logit per video modality
        self.moe_fc_txt = nn.Linear(
            in_features=self.text_pooling.out_dim,
            out_features=len(video_modality),
            bias=False)
        self.video_modality = video_modality
    def forward(self, query_feat):
        # NetVLAD aggregation
        pooled_text = self.text_pooling(query_feat)
        pooled_text = self.moe_txt_dropout(pooled_text)
        # FC + Softmax: fusion weights over modalities, summing to 1 per sample
        moe_weights = self.moe_fc_txt(pooled_text)
        softmax_moe_weights = F.softmax(moe_weights, dim=1)
        moe_weights_dict = dict()
        for modality, moe_weight in zip(self.video_modality,
                                        torch.split(softmax_moe_weights, 1, dim=1)):
            moe_weights_dict[modality] = moe_weight.squeeze(1)
        return moe_weights_dict
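

if __name__ == "__main__":
    # Minimal smoke test for QueryWeightEncoder -- a sketch, not part of the
    # original training code. It assumes NetVLAD (model.layers) accepts a
    # (batch, num_tokens, feature_size) tensor and exposes `out_dim`, matching
    # how QueryWeightEncoder uses it above; the config values are made up.
    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=768, text_cluster=32, moe_dropout_prob=0.1)
    encoder = QueryWeightEncoder(cfg, video_modality=["visual", "sub"])
    query_feat = torch.randn(2, 20, 768)  # (batch, query_len, hidden_size)
    weights = encoder(query_feat)
    # one scalar weight per modality per sample; weights sum to 1 across modalities
    print({k: tuple(v.shape) for k, v in weights.items()})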