import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid

try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is not compiled; AlternateCorrBlock cannot be used.
    alt_cuda_corr = None


def _window_offsets(radius, device=None):
    """Return the (2r+1, 2r+1, 2) grid of offsets spanning [-radius, radius].

    Entry ``[i, j]`` is ``(steps[i], steps[j])`` where ``steps`` is
    ``linspace(-radius, radius, 2*radius+1)``.  This is exactly
    ``torch.stack(torch.meshgrid(dy, dx), -1)`` with ``dx == dy`` (the
    historical construction), built with ``expand`` so it does not depend on
    ``meshgrid``'s deprecated default ``indexing`` argument.

    NOTE(review): component 0 of each offset is added to component 0 of the
    sampling coordinates.  Because the window is symmetric this only fixes the
    order in which window samples appear in the output channels; keep it as is
    so downstream (possibly pretrained) layers see the same layout.
    """
    steps = torch.linspace(-radius, radius, 2 * radius + 1, device=device)
    n = steps.numel()
    first = steps.view(n, 1).expand(n, n)   # varies along the first grid axis
    second = steps.view(1, n).expand(n, n)  # varies along the second grid axis
    return torch.stack((first, second), dim=-1)


class CorrBlock2:
    """Correlation pyramid built by downsampling ``fmap2`` between levels.

    Each level is a fresh all-pairs correlation between the full-resolution
    ``fmap1`` and ``fmap2`` bilinearly downsampled by ``0.5**level``.
    """

    def __init__(self, fmap1, fmap2, args):
        """Precompute the correlation volume for every pyramid level.

        Args:
            fmap1, fmap2: feature maps shaped (batch, dim, h, w) (as unpacked
                by :meth:`corr`).
            args: object exposing ``corr_levels`` and ``corr_radius``.
        """
        self.num_levels = args.corr_levels
        self.radius = args.corr_radius
        self.args = args
        self.corr_pyramid = []

        # all pairs correlation
        for i in range(self.num_levels):
            corr = CorrBlock2.corr(fmap1, fmap2, 1)
            batch, h1, w1, dim, h2, w2 = corr.shape
            # One (dim, h2, w2) "image" per fmap1 pixel so it can be sampled.
            corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
            self.corr_pyramid.append(corr)
            # Only downsample when another level still needs it.
            if i + 1 < self.num_levels:
                fmap2 = F.interpolate(fmap2, scale_factor=0.5,
                                      mode='bilinear', align_corners=False)

    def __call__(self, coords, dilation=None):
        """Sample a (2r+1)x(2r+1) window around ``coords`` at every level.

        Args:
            coords: tensor shaped (batch, 2, h1, w1) of sampling centres in
                level-0 pixel units (scaled by ``1 / 2**i`` for level ``i``).
            dilation: optional (batch, 1, h1, w1) per-pixel window spacing;
                defaults to all ones.

        Returns:
            Float tensor shaped (batch, C, h1, w1), where C is the per-level
            channel count returned by ``bilinear_sampler`` (presumably
            (2r+1)**2 -- TODO confirm) times ``num_levels``.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        n_pix = batch * h1 * w1

        if dilation is None:
            dilation = torch.ones(batch, 1, h1, w1, device=coords.device)

        # The dilated window and flattened centres do not depend on the level.
        window = _window_offsets(r, device=coords.device).view(1, 2 * r + 1, 2 * r + 1, 2)
        delta_lvl = window * dilation.reshape(n_pix, 1, 1, 1)
        centroid = coords.reshape(n_pix, 1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            coords_lvl = centroid / 2**i + delta_lvl
            sampled = bilinear_sampler(self.corr_pyramid[i], coords_lvl)
            out_pyramid.append(sampled.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2, num_head):
        """All-pairs dot-product correlation, optionally split into heads.

        Returns a tensor shaped (batch, h1, w1, num_head, h2, w2), scaled by
        ``1 / sqrt(dim)`` where ``dim`` is the full channel count.
        """
        batch, dim, h1, w1 = fmap1.shape
        h2, w2 = fmap2.shape[2:]
        fmap1 = fmap1.reshape(batch, num_head, dim // num_head, h1 * w1)
        fmap2 = fmap2.reshape(batch, num_head, dim // num_head, h2 * w2)
        corr = fmap1.transpose(2, 3) @ fmap2
        corr = corr.reshape(batch, num_head, h1, w1, h2, w2).permute(0, 2, 3, 1, 4, 5)
        return corr / torch.sqrt(torch.tensor(dim).float())


class CorrBlock:
    """Correlation pyramid built by average-pooling one all-pairs volume."""

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Compute the full correlation once and pool it ``num_levels-1`` times.

        Args:
            fmap1, fmap2: feature maps shaped (batch, dim, h, w); their spatial
                sizes may differ.
            num_levels: number of pyramid levels.
            radius: lookup window radius used by :meth:`__call__`.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
        self.corr_pyramid.append(corr)
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample a (2r+1)x(2r+1) window around ``coords`` at every level.

        Args:
            coords: tensor shaped (batch, 2, h1, w1) of sampling centres in
                level-0 pixel units.

        Returns:
            Float tensor shaped (batch, C, h1, w1) with the per-level samples
            concatenated along channels.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # Built once, directly on the target device, and reused for all levels.
        delta_lvl = _window_offsets(r, device=coords.device).view(1, 2 * r + 1, 2 * r + 1, 2)
        centroid = coords.reshape(batch * h1 * w1, 1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            coords_lvl = centroid / 2**i + delta_lvl
            sampled = bilinear_sampler(self.corr_pyramid[i], coords_lvl)
            out_pyramid.append(sampled.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """All-pairs dot-product correlation scaled by ``1 / sqrt(dim)``.

        Returns a tensor shaped (batch, h1, w1, 1, h2, w2).  ``fmap2`` may have
        a different spatial size than ``fmap1``; channel counts must match.
        """
        batch, dim, ht, wd = fmap1.shape
        h2, w2 = fmap2.shape[2:]
        fmap1 = fmap1.reshape(batch, dim, ht * wd)
        fmap2 = fmap2.reshape(batch, dim, h2 * w2)
        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, h2, w2)
        return corr / torch.sqrt(torch.tensor(dim).float())


class AlternateCorrBlock:
    """Memory-efficient correlation lookup using the ``alt_cuda_corr`` kernel.

    Instead of materialising the all-pairs volume, correlations are computed
    on demand by the compiled extension at lookup time.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Build an average-pooled pyramid of (fmap1, fmap2) pairs.

        NOTE(review): :meth:`__call__` only reads ``pyramid[0][0]`` and
        ``pyramid[i][1]`` for ``i < num_levels``; the pooled ``fmap1`` levels
        and the final extra level are kept for compatibility with the
        ``pyramid`` attribute's existing layout.
        """
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for _ in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        """Look up windowed correlations for every level via ``alt_cuda_corr``.

        Args:
            coords: tensor shaped (batch, 2, H, W) of level-0 sampling centres.

        Returns:
            Tensor shaped (batch, C, H, W) scaled by ``1 / sqrt(dim)``.

        Raises:
            ImportError: if the ``alt_cuda_corr`` extension is not available.
        """
        if alt_cuda_corr is None:
            raise ImportError(
                "alt_cuda_corr is not compiled; AlternateCorrBlock is unavailable")

        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]
        r = self.radius

        # fmap1 is always used at full resolution: convert it to channels-last once.
        fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()

        corr_list = []
        for i in range(self.num_levels):
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())