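"""Coarse-to-fine flow decoder: a pyramid of DecoderBlocks driven by CINN.

Each DecoderBlock refines an accumulated flow field and emits optional
message / local-message / mask / feature-flow channels; CINN runs the
blocks over a feature pyramid and fuses the per-view outputs with
softmax-normalized confidence masks.
"""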
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp_features

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DecoderBlock(nn.Module):
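    """One pyramid level of the decoder.

    Concatenates the accumulated flow with any auxiliary inputs, encodes
    them with a strided head and residual conv blocks, then upsamples and
    splits the result into flow, message, message-flow, local-message,
    local-message-flow, mask, and feature-flow channel groups.
    """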
    def __init__(self, in_planes, c=224, out_msgs=0, out_locals=0,
                 block_nums=1, out_masks=1, out_local_flows=32,
                 out_msgs_flows=32, out_feat_flows=0):
        super().__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2d(in_planes, c, 3, 2, 1),
            nn.PReLU(c),
            nn.Conv2d(c, c, 3, 2, 1),
            nn.PReLU(c),
        )
        self.convblocks = nn.ModuleList()
        for _ in range(block_nums):
            # six conv + PReLU pairs per residual block
            layers = []
            for _ in range(6):
                layers.extend([nn.Conv2d(c, c, 3, 1, 1), nn.PReLU(c)])
            self.convblocks.append(nn.Sequential(*layers))
        self.out_flows = 2
        self.out_msgs = out_msgs
        self.out_msgs_flows = out_msgs_flows if out_msgs > 0 else 0
        self.out_locals = out_locals
        self.out_local_flows = out_local_flows if out_locals > 0 else 0
        self.out_masks = out_masks
        self.out_feat_flows = out_feat_flows
        out_total = (self.out_flows + self.out_msgs + self.out_msgs_flows
                     + self.out_locals + self.out_local_flows
                     + self.out_masks + self.out_feat_flows)
        self.conv_last = nn.Sequential(
            nn.ConvTranspose2d(c, c, 4, 2, 1),
            nn.PReLU(c),
            nn.ConvTranspose2d(c, out_total, 4, 2, 1),
        )
    def forward(self, accumulated_flow, *other):
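        """Concatenate the non-None inputs, refine them through the
        residual blocks, and split the upsampled output by channel group."""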
        x = [accumulated_flow]
        for each in other:
            if each is not None:
                assert accumulated_flow.shape[-1] == each.shape[-1], \
                    "decoder expected {}, but got {}".format(
                        accumulated_flow.shape, each.shape)
                x.append(each)
        feat = self.conv0(torch.cat(x, dim=1))
        for convblock in self.convblocks:
            feat = convblock(feat) + feat  # residual connection
        feat = self.conv_last(feat)
        # slice the output tensor into its named channel groups
        prev = 0
        flow = feat[:, prev:prev + self.out_flows]
        prev += self.out_flows
        message = feat[:, prev:prev + self.out_msgs] if self.out_msgs > 0 else None
        prev += self.out_msgs
        message_flow = feat[:, prev:prev + self.out_msgs_flows] if self.out_msgs_flows > 0 else None
        prev += self.out_msgs_flows
        local_message = feat[:, prev:prev + self.out_locals] if self.out_locals > 0 else None
        prev += self.out_locals
        local_message_flow = feat[:, prev:prev + self.out_local_flows] if self.out_local_flows > 0 else None
        prev += self.out_local_flows
        mask = torch.sigmoid(feat[:, prev:prev + self.out_masks]) if self.out_masks > 0 else None
        prev += self.out_masks
        feat_flow = feat[:, prev:prev + self.out_feat_flows] if self.out_feat_flows > 0 else None
        return flow, mask, message, message_flow, local_message, local_message_flow, feat_flow

class CINN(nn.Module):
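    """Multi-view coarse-to-fine decoder.

    Builds one DecoderBlock per feature-pyramid level (plus extra
    full-resolution blocks), warps the reference image and features with
    the accumulated flow at each level, and fuses per-view results with
    softmax-normalized confidence masks (the "weighted" inputs).
    """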
    def __init__(self, DIM_SHADER_REFERENCE,
                 target_feature_chns=[512, 256, 128, 64, 64],
                 feature_chns=[2048, 1024, 512, 256, 64],
                 out_msgs_chn=[2048, 1024, 512, 256, 64, 64],
                 out_locals_chn=[2048, 1024, 512, 256, 64, 0],
                 block_num=[1, 1, 1, 1, 1, 2],
                 block_chn_num=[224, 224, 224, 224, 224, 224]):
        super().__init__()
        # each level consumes the messages emitted by the previous level
        self.in_msgs_chn = [0, *out_msgs_chn[:-1]]
        self.in_locals_chn = [0, *out_locals_chn[:-1]]
        self.decoder_blocks = nn.ModuleList()
        self.feed_weighted = True
        if self.feed_weighted:
            # flow + weighted flow + reference + weighted reference
            in_planes = 2 + 2 + DIM_SHADER_REFERENCE * 2
        else:
            in_planes = 2 + DIM_SHADER_REFERENCE
        for (each_target_feature_chns, each_feature_chns, each_out_msgs_chn,
             each_out_locals_chn, each_in_msgs_chn, each_in_locals_chn,
             each_block_num, each_block_chn_num) in zip(
                target_feature_chns, feature_chns, out_msgs_chn,
                out_locals_chn, self.in_msgs_chn, self.in_locals_chn,
                block_num, block_chn_num):
            self.decoder_blocks.append(DecoderBlock(
                in_planes + each_target_feature_chns + each_feature_chns
                + each_in_locals_chn + each_in_msgs_chn,
                c=each_block_chn_num, block_nums=each_block_num,
                out_msgs=each_out_msgs_chn, out_locals=each_out_locals_chn,
                out_masks=2 + each_out_locals_chn))
        # extra blocks beyond the feature pyramid (no target/reference features)
        for i in range(len(feature_chns), len(out_locals_chn)):
            self.decoder_blocks.append(DecoderBlock(
                in_planes + self.in_msgs_chn[i] + self.in_locals_chn[i],
                c=block_chn_num[i], block_nums=block_num[i],
                out_msgs=out_msgs_chn[i], out_locals=out_locals_chn[i],
                out_masks=2 + out_msgs_chn[i], out_feat_flows=0))
    def apply_flow(self, mask, message, message_flow, local_message,
                   local_message_flow, x_reference, accumulated_flow,
                   each_x_reference_features=None,
                   each_x_reference_features_flow=None):
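        """Rescale all flows to this level's resolution, resize the
        auxiliary maps, and warp each input with its matching flow."""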
        # take the target resolution from the feature map when available
        if each_x_reference_features is not None:
            size_from = each_x_reference_features
        else:
            size_from = x_reference
        f_size = (size_from.shape[2], size_from.shape[3])
        accumulated_flow = self.flow_rescale(accumulated_flow, size_from)
        # mask = warp_features(F.interpolate(
        #     mask, size=f_size, mode="bilinear"), accumulated_flow) if mask is not None else None
        mask = F.interpolate(
            mask, size=f_size, mode="bilinear") if mask is not None else None
        message = F.interpolate(
            message, size=f_size, mode="bilinear") if message is not None else None
        message_flow = self.flow_rescale(
            message_flow, size_from) if message_flow is not None else None
        message = warp_features(
            message, message_flow) if message_flow is not None else message
        local_message = F.interpolate(
            local_message, size=f_size, mode="bilinear") if local_message is not None else None
        local_message_flow = self.flow_rescale(
            local_message_flow, size_from) if local_message_flow is not None else None
        local_message = warp_features(
            local_message, local_message_flow) if local_message_flow is not None else local_message
        # warp the reference image with the accumulated flow
        warp_x_reference = warp_features(F.interpolate(
            x_reference, size=f_size, mode="bilinear"), accumulated_flow)
        each_x_reference_features_flow = self.flow_rescale(
            each_x_reference_features_flow, size_from) if (
            each_x_reference_features is not None
            and each_x_reference_features_flow is not None) else None
        warp_each_x_reference_features = warp_features(
            each_x_reference_features, each_x_reference_features_flow
        ) if each_x_reference_features_flow is not None else each_x_reference_features
        return (mask, message, local_message, warp_x_reference, accumulated_flow,
                warp_each_x_reference_features, each_x_reference_features_flow)
    def forward(self, x_target_features=[], x_reference=None, x_reference_features=[]):
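        """Run the decoder pyramid over every view of x_reference (NMCHW).

        Per-view outputs are fused across views with confidence-weighted
        sums; returns {"y_last_remote_features": [weighted_mesg]}.
        """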
        y_flow = []
        y_feat_flow = []
        y_local_message = []
        y_warp_x_reference = []
        y_warp_x_reference_features = []
        y_weighted_flow = []
        y_weighted_mask = []
        y_weighted_message = []
        y_weighted_x_reference = []
        y_weighted_x_reference_features = []
        for pyrlevel, ifblock in enumerate(self.decoder_blocks):
            stacked_wref = []
            stacked_feat = []
            stacked_flow = []
            stacked_mask = []
            stacked_mesg = []
            stacked_locm = []
            stacked_feat_flow = []
            for view_id in range(x_reference.shape[1]):  # NMCHW
                if pyrlevel == 0:
                    # start from zero flow at the coarsest level
                    feat_ev = x_reference_features[pyrlevel][
                        :, view_id, :, :, :] if pyrlevel < len(x_reference_features) else None
                    accumulated_flow = torch.zeros_like(feat_ev[:, :2, :, :]).to(device)
                    accumulated_feat_flow = torch.zeros_like(feat_ev[:, :32, :, :]).to(device)
                    # domestic inputs
                    warp_x_reference = F.interpolate(
                        x_reference[:, view_id, :, :, :],
                        size=(feat_ev.shape[-2], feat_ev.shape[-1]), mode="bilinear")
                    warp_x_reference_features = feat_ev
                    local_message = None
                    # federated inputs
                    weighted_flow = accumulated_flow if self.feed_weighted else None
                    weighted_wref = warp_x_reference if self.feed_weighted else None
                    weighted_message = None
                else:
                    # resume from the previous level
                    accumulated_flow = y_flow[-1][:, view_id, :, :, :]
                    accumulated_feat_flow = y_feat_flow[-1][
                        :, view_id, :, :, :] if y_feat_flow[-1] is not None else None
                    # domestic inputs
                    warp_x_reference = y_warp_x_reference[-1][:, view_id, :, :, :]
                    warp_x_reference_features = y_warp_x_reference_features[-1][
                        :, view_id, :, :, :] if y_warp_x_reference_features[-1] is not None else None
                    local_message = y_local_message[-1][
                        :, view_id, :, :, :] if len(y_local_message) > 0 else None
                    # federated inputs
                    weighted_flow = y_weighted_flow[-1] if self.feed_weighted else None
                    weighted_wref = y_weighted_x_reference[-1] if self.feed_weighted else None
                    weighted_message = y_weighted_message[-1] if len(y_weighted_message) > 0 else None
                scaled_x_target = x_target_features[pyrlevel].detach() \
                    if pyrlevel < len(x_target_features) else None
                # compute the residual flow for this level
                (residual_flow, mask, message, message_flow, local_message,
                 local_message_flow, residual_feat_flow) = ifblock(
                    accumulated_flow, scaled_x_target, warp_x_reference,
                    warp_x_reference_features, weighted_flow, weighted_wref,
                    weighted_message, local_message)
                accumulated_flow = residual_flow + accumulated_flow
                # the feature flow simply tracks the image flow here
                # (residual_feat_flow stays None while out_feat_flows == 0)
                accumulated_feat_flow = accumulated_flow
                feat_ev = x_reference_features[pyrlevel + 1][
                    :, view_id, :, :, :] if pyrlevel + 1 < len(x_reference_features) else None
                (mask, message, local_message, warp_x_reference, accumulated_flow,
                 warp_x_reference_features, accumulated_feat_flow) = self.apply_flow(
                    mask, message, message_flow, local_message, local_message_flow,
                    x_reference[:, view_id, :, :, :], accumulated_flow,
                    feat_ev, accumulated_feat_flow)
                stacked_flow.append(accumulated_flow)
                if accumulated_feat_flow is not None:
                    stacked_feat_flow.append(accumulated_feat_flow)
                stacked_mask.append(mask)
                if message is not None:
                    stacked_mesg.append(message)
                if local_message is not None:
                    stacked_locm.append(local_message)
                stacked_wref.append(warp_x_reference)
                if warp_x_reference_features is not None:
                    stacked_feat.append(warp_x_reference_features)
            stacked_flow = torch.stack(stacked_flow, dim=1)  # M*NCHW -> NMCHW
            stacked_feat_flow = torch.stack(
                stacked_feat_flow, dim=1) if len(stacked_feat_flow) > 0 else None
            stacked_mask = torch.stack(stacked_mask, dim=1)
            stacked_mesg = torch.stack(
                stacked_mesg, dim=1) if len(stacked_mesg) > 0 else None
            stacked_locm = torch.stack(
                stacked_locm, dim=1) if len(stacked_locm) > 0 else None
            stacked_wref = torch.stack(stacked_wref, dim=1)
            stacked_feat = torch.stack(
                stacked_feat, dim=1) if len(stacked_feat) > 0 else None
            y_flow.append(stacked_flow)
            y_feat_flow.append(stacked_feat_flow)
            y_warp_x_reference.append(stacked_wref)
            y_warp_x_reference_features.append(stacked_feat)
            # compute normalized confidence across views
            stacked_contrib = F.softmax(stacked_mask, dim=1)
            # torch.sum removes the view dimension M: NMCHW -> NCHW
            weighted_flow = torch.sum(
                stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_flow, dim=1)
            weighted_mask = torch.sum(
                stacked_contrib[:, :, 0:1, :, :] * stacked_mask[:, :, 0:1, :, :], dim=1)
            weighted_wref = torch.sum(
                stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_wref,
                dim=1) if stacked_wref is not None else None
            weighted_feat = torch.sum(
                stacked_mask[:, :, 1:2, :, :] * stacked_contrib[:, :, 1:2, :, :] * stacked_feat,
                dim=1) if stacked_feat is not None else None
            weighted_mesg = torch.sum(
                stacked_mask[:, :, 2:, :, :] * stacked_contrib[:, :, 2:, :, :] * stacked_mesg,
                dim=1) if stacked_mesg is not None else None
            y_weighted_flow.append(weighted_flow)
            y_weighted_mask.append(weighted_mask)
            if weighted_mesg is not None:
                y_weighted_message.append(weighted_mesg)
            if stacked_locm is not None:
                y_local_message.append(stacked_locm)
            y_weighted_x_reference.append(weighted_wref)
            if weighted_feat is not None:
                y_weighted_x_reference_features.append(weighted_feat)
        return {
            "y_last_remote_features": [weighted_mesg],
        }
    def flow_rescale(self, prev_flow, each_x_reference_features):
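        """Bilinearly resize a flow field to match a reference tensor,
        scaling the flow vectors by the same factor (zero flow if None)."""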
        if prev_flow is None:
            prev_flow = torch.zeros_like(
                each_x_reference_features[:, :2]).to(device)
        else:
            up_scale_factor = (each_x_reference_features.shape[-1]
                               / prev_flow.shape[-1])
            if up_scale_factor != 1:
                # flow vectors must be scaled together with the resolution
                prev_flow = F.interpolate(
                    prev_flow, scale_factor=up_scale_factor, mode="bilinear",
                    align_corners=False, recompute_scale_factor=False) * up_scale_factor
        return prev_flow
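

# ---------------------------------------------------------------------------
# Minimal shape-check sketch (not part of the original module). It exercises
# a single DecoderBlock with assumed toy sizes: the flow-only input
# (in_planes=2), width c=32, and the 64x64 resolution are illustrative
# choices, and running it still requires the package context that provides
# the relative `.warplayer` import.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # the default DecoderBlock emits flow (2ch) + mask (1ch); all other
    # outputs are None because out_msgs / out_locals / out_feat_flows are 0
    block = DecoderBlock(in_planes=2, c=32).to(device)
    zero_flow = torch.zeros(1, 2, 64, 64, device=device)
    outs = block(zero_flow)
    print([None if o is None else tuple(o.shape) for o in outs])
    # expected: [(1, 2, 64, 64), (1, 1, 64, 64), None, None, None, None, None]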