import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from update import BasicUpdateBlock, SmallUpdateBlock, BasicUpdateBlock2
from extractor import BasicEncoder, SmallEncoder, ResNetFPN
from corr import CorrBlock, AlternateCorrBlock, CorrBlock2
from utils.utils import bilinear_sampler, coords_grid, upflow8, InputPadder, coords_grid2
from layer import conv3x3

try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        """No-op stand-in used when ``torch.cuda.amp.autocast`` is unavailable."""

        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass


class RAFT(nn.Module):
    """Optical-flow network that refines a 1/8-resolution flow field iteratively.

    ``args`` is mutated in place: ``corr_levels`` and ``corr_radius`` are
    overwritten according to ``args.small``, and ``dropout``,
    ``alternate_corr`` and ``mixed_precision`` receive defaults when absent.
    ``args`` must support the ``in`` operator (e.g. ``argparse.Namespace``).
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3
        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # forward() reads this flag unconditionally, so give it the same
        # "default when missing" treatment as the two options above.
        if 'mixed_precision' not in self.args:
            self.args.mixed_precision = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0

        Returns two coordinate grids built for batch size ``N`` and spatial
        size ``H//8 x W//8``, moved to ``img``'s device.
        """
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination

        ``mask`` is viewed as ``[N, 1, 9, 8, 8, H, W]`` and soft-maxed over the
        9 entries of each 3x3 neighbourhood, so every output pixel is a convex
        combination of the (8x-scaled) coarse flow vectors around it.
        """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Flow magnitudes are multiplied by 8 to match the 8x finer grid.
        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        # [N, 2, 8, 8, H, W] -> [N, 2, H, 8, W, 8] so the reshape interleaves
        # each 8x8 sub-pixel patch with its coarse location.
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames

        Args:
            image1, image2: 4-D image batches; values are rescaled with
                ``2 * (x / 255) - 1`` (presumably 0..255 input - confirm with caller).
            iters: number of refinement iterations.
            flow_init: optional initial flow added to the starting coordinates.
            upsample: accepted for interface compatibility; not read here.
            test_mode: if true, return only the final estimate.

        Returns:
            In test mode, ``(flow_low, flow_up)``: the final 1/8-resolution flow
            and its upsampled version. Otherwise the list of upsampled flow
            predictions, one per iteration.
        """
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # Gradients are not propagated through the coordinates of the
            # previous iteration.
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            flow_low = coords1 - coords0
            if not flow_predictions:
                # iters == 0: no refinement step ran, so there is no learned
                # mask; upsample the initial flow instead of failing on an
                # unbound ``flow_up``.
                return flow_low, upflow8(flow_low)
            return flow_low, flow_predictions[-1]

        return flow_predictions


##
# given depth,
# warp according to camera params.
# given flow+depth, warp in 2D
class RAFT2(nn.Module):
    """Flow network that predicts an initial flow from the stacked image pair
    and optionally refines it with a recurrent update block.

    Besides the flow, each prediction carries a 4-channel ``info`` map whose
    first two channels are used as mixture weights and last two as log-scale
    values in the training loss.

    ``args`` is mutated in place (``corr_levels``, ``corr_radius`` and
    ``corr_channel`` are set). It must provide ``dim``, ``radius`` and
    ``iters``; the training path of ``forward`` additionally reads
    ``use_var`` and, when that is true, ``var_min`` / ``var_max``.
    """

    def __init__(self, args):
        super(RAFT2, self).__init__()
        self.args = args
        self.output_dim = args.dim * 2

        self.args.corr_levels = 4
        self.args.corr_radius = args.radius
        self.args.corr_channel = args.corr_levels * (args.radius * 2 + 1) ** 2
        # The context network sees both frames stacked on the channel axis
        # (hence input_dim=6).
        self.cnet = ResNetFPN(args, input_dim=6, output_dim=2 * self.args.dim, norm_layer=nn.BatchNorm2d, init_weight=True)

        # conv for iter 0 results
        self.init_conv = conv3x3(2 * args.dim, 2 * args.dim)
        self.upsample_weight = nn.Sequential(
            # convex combination of 3x3 patches
            nn.Conv2d(args.dim, args.dim * 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(args.dim * 2, 64 * 9, 1, padding=0)
        )
        self.flow_head = nn.Sequential(
            # flow(2) + weight(2) + log_b(2)
            nn.Conv2d(args.dim, 2 * args.dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * args.dim, 6, 3, padding=1)
        )
        # The feature network and update block are only needed when at least
        # one refinement iteration is configured.
        if args.iters > 0:
            self.fnet = ResNetFPN(args, input_dim=3, output_dim=self.output_dim, norm_layer=nn.BatchNorm2d, init_weight=True)
            self.update_block = BasicUpdateBlock2(args, hdim=args.dim, cdim=args.dim)

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords2 - coords1"""
        # NOTE(review): forward() builds its grid with coords_grid2(..., device=...)
        # while this method passes device= to coords_grid; confirm that
        # coords_grid accepts a ``device`` keyword.
        N, C, H, W = img.shape
        coords1 = coords_grid(N, H//8, W//8, device=img.device)
        coords2 = coords_grid(N, H//8, W//8, device=img.device)
        return coords1, coords2

    def upsample_data(self, flow, info, mask):
        """ Upsample [H/8, W/8, C] -> [H, W, C] using convex combination

        ``flow`` (2 channels, scaled by 8) and ``info`` (``C`` channels, not
        scaled) are upsampled with the same soft-maxed 3x3 weights taken from
        ``mask``. Returns ``(up_flow, up_info)`` at 8x the input resolution.
        """
        N, C, H, W = info.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
        up_info = F.unfold(info, [3, 3], padding=1)
        up_info = up_info.view(N, C, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        up_info = torch.sum(mask * up_info, dim=2)
        up_info = up_info.permute(0, 1, 4, 2, 5, 3)

        return up_flow.reshape(N, 2, 8*H, 8*W), up_info.reshape(N, C, 8*H, 8*W)

    def _nf_loss(self, flow_gt, flow_pred, info_pred, var_min, var_max):
        """Per-pixel loss for one prediction, from its flow error and info map."""
        weight = info_pred[:, :2]
        raw_b = info_pred[:, 2:]
        log_b = torch.zeros_like(raw_b)
        # Large b Component
        log_b[:, 0] = torch.clamp(raw_b[:, 0], min=0, max=var_max)
        # Small b Component
        log_b[:, 1] = torch.clamp(raw_b[:, 1], min=var_min, max=0)
        # term2: [N, 2, m, H, W]
        term2 = (flow_gt - flow_pred).abs().unsqueeze(2) * torch.exp(-log_b).unsqueeze(1)
        # term1: [N, m, H, W]
        term1 = weight - math.log(2) - log_b
        return torch.logsumexp(weight, dim=1, keepdim=True) - torch.logsumexp(term1.unsqueeze(1) - term2, dim=2)

    def forward(self, image1, image2, iters=None, flow_gt=None, test_mode=False):
        """ Estimate optical flow between pair of frames

        Args:
            image1, image2: 4-D image batches; values are rescaled with
                ``2 * (x / 255) - 1`` (presumably 0..255 input - confirm with caller).
            iters: number of refinement iterations; defaults to ``args.iters``.
                Refinement is skipped entirely when ``args.iters`` is 0.
            flow_gt: ground-truth flow used for the loss; a zero tensor of the
                unpadded input size is substituted when omitted.
            test_mode: if true, skip the loss computation.

        Returns:
            In test mode, ``[flow_predictions, flow_predictions[-1]]``.
            Otherwise a dict with keys ``'final'``, ``'flow'``, ``'info'`` and
            ``'nf'`` (one loss map per prediction).
        """
        # Remember the unpadded size; it is needed for the flow_gt placeholder.
        batch, _, height, width = image1.shape
        device = image1.device
        if iters is None:
            iters = self.args.iters

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0
        image1 = image1.contiguous()
        image2 = image2.contiguous()
        flow_predictions = []
        info_predictions = []

        # padding
        padder = InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)
        N, _, H, W = image1.shape
        dilation = torch.ones(N, 1, H//8, W//8, device=image1.device)

        # run the context network
        cnet = self.cnet(torch.cat([image1, image2], dim=1))
        cnet = self.init_conv(cnet)
        net, context = torch.split(cnet, [self.args.dim, self.args.dim], dim=1)

        # init flow
        flow_update = self.flow_head(net)
        weight_update = .25 * self.upsample_weight(net)
        flow_8x = flow_update[:, :2]
        info_8x = flow_update[:, 2:]
        flow_up, info_up = self.upsample_data(flow_8x, info_8x, weight_update)
        flow_predictions.append(flow_up)
        info_predictions.append(info_up)

        if self.args.iters > 0:
            # run the feature network
            fmap1_8x = self.fnet(image1)
            fmap2_8x = self.fnet(image2)
            corr_fn = CorrBlock2(fmap1_8x, fmap2_8x, self.args)

        for itr in range(iters):
            N, _, H, W = flow_8x.shape
            # Gradients are not propagated through the previous flow estimate.
            flow_8x = flow_8x.detach()
            coords2 = (coords_grid2(N, H, W, device=image1.device) + flow_8x).detach()
            corr = corr_fn(coords2, dilation=dilation)
            net = self.update_block(net, context, corr, flow_8x)
            flow_update = self.flow_head(net)
            weight_update = .25 * self.upsample_weight(net)
            flow_8x = flow_8x + flow_update[:, :2]
            info_8x = flow_update[:, 2:]
            # upsample predictions
            flow_up, info_up = self.upsample_data(flow_8x, info_8x, weight_update)
            flow_predictions.append(flow_up)
            info_predictions.append(info_up)

        # Strip the padding added above.
        flow_predictions = [padder.unpad(flow) for flow in flow_predictions]
        info_predictions = [padder.unpad(info) for info in info_predictions]

        if test_mode:
            return [flow_predictions, flow_predictions[-1]]

        # Only the training path reads flow_gt, so the zero placeholder is
        # allocated here rather than on every call.
        if flow_gt is None:
            flow_gt = torch.zeros(batch, 2, height, width, device=device)

        # The clamp bounds do not change between predictions.
        if self.args.use_var:
            var_max = self.args.var_max
            var_min = self.args.var_min
        else:
            var_max = var_min = 0

        # exclude invalid pixels and extremely large displacements
        nf_predictions = [
            self._nf_loss(flow_gt, flow_pred, info_pred, var_min, var_max)
            for flow_pred, info_pred in zip(flow_predictions, info_predictions)
        ]

        return {'final': flow_predictions[-1], 'flow': flow_predictions, 'info': info_predictions, 'nf': nf_predictions}