Vincentqyw
fix: roma
c74a070
import math
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import get_tuple_transform_ops
from einops import rearrange
from ..utils.local_correlation import local_correlation
class ConvRefiner(nn.Module):
    """Convolutional head that refines a warp at a single scale.

    The support features ``y`` are bilinearly sampled at the current warp
    ``flow``. They are then concatenated with the query features ``x`` and,
    optionally, with:

    * an embedding of the warp's displacement from the identity grid, and
    * a local correlation volume.

    The result goes through ``block1 -> hidden_blocks -> out_conv``. The last
    two output channels are the displacement; the leading ``out_dim - 2``
    channels are certainty.

    Args:
        in_dim: channels of the concatenated input to ``block1``.
        hidden_dim: channels used inside the conv stack.
        out_dim: output channels (certainty channels + 2 displacement channels).
        dw: use grouped (depthwise) first convolutions in each block.
        kernel_size: spatial kernel size of the first conv in each block.
        hidden_blocks: number of hidden blocks after ``block1``.
        displacement_emb: truthy to enable the displacement embedding.
        displacement_emb_dim: output channels of the displacement embedding.
        local_corr_radius: radius passed to ``local_correlation``. It is only
            used when the displacement embedding is enabled.
        corr_in_other: correlate against ``y`` around ``flow`` (True) or
            against the warped features ``x_hat`` (False).
        no_support_fm: replace the warped support features by zeros.
    """

    def __init__(
        self,
        in_dim=6,
        hidden_dim=16,
        out_dim=2,
        dw=False,
        kernel_size=5,
        hidden_blocks=3,
        displacement_emb=None,
        displacement_emb_dim=None,
        local_corr_radius=None,
        corr_in_other=None,
        no_support_fm=False,
    ):
        super().__init__()
        # Creation order (block1, hidden blocks, out_conv, disp_emb) fixes the
        # parameter-initialisation order and the state_dict keys.
        self.block1 = self.create_block(
            in_dim, hidden_dim, dw=dw, kernel_size=kernel_size
        )
        stacked = []
        for _ in range(hidden_blocks):
            stacked.append(
                self.create_block(
                    hidden_dim, hidden_dim, dw=dw, kernel_size=kernel_size
                )
            )
        self.hidden_blocks = nn.Sequential(*stacked)
        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
        self.has_displacement_emb = bool(displacement_emb)
        if self.has_displacement_emb:
            # 1x1 conv embedding the 2-channel displacement (warp - identity grid).
            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
        self.local_corr_radius = local_corr_radius
        self.corr_in_other = corr_in_other
        self.no_support_fm = no_support_fm

    def create_block(
        self,
        in_dim,
        out_dim,
        dw=False,
        kernel_size=5,
    ):
        """Build ``conv(k x k) -> BatchNorm -> ReLU -> conv(1 x 1)``.

        When ``dw`` is set, the first convolution uses ``groups=in_dim``, which
        requires ``out_dim`` to be a multiple of ``in_dim``.
        """
        if dw:
            assert (
                out_dim % in_dim == 0
            ), "outdim must be divisible by indim for depthwise"
        groups = in_dim if dw else 1
        return nn.Sequential(
            nn.Conv2d(
                in_dim,
                out_dim,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                groups=groups,
            ),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_dim, out_dim, 1, 1, 0),
        )

    def forward(self, x, y, flow):
        """Predict a certainty update and a displacement for the warp ``flow``.

        Args:
            x: query features, shape (b, c, hs, ws).
            y: support features, sampled with ``F.grid_sample``.
            flow: warp as a 2-channel map; permuted to (b, hs, ws, 2) for
                ``grid_sample``, so presumably normalised [-1, 1] coordinates.

        Returns:
            ``(certainty, displacement)``: all output channels but the last
            two, and the last two channels.
        """
        b, _, hs, ws = x.shape
        # The sampled support features are treated as constants.
        with torch.no_grad():
            x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)

        if not self.has_displacement_emb:
            if self.no_support_fm:
                x_hat = torch.zeros_like(x)
            stacked_input = torch.cat((x, x_hat), dim=1)
        else:
            # Identity grid at pixel centres, stacked as (x, y) channels.
            grid_y, grid_x = torch.meshgrid(
                (
                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
                )
            )
            identity = torch.stack((grid_x, grid_y))[None].expand(b, 2, hs, ws)
            parts = [x, x_hat, self.disp_emb(flow - identity)]
            if self.local_corr_radius:
                # TODO: should corr have gradient?
                if self.corr_in_other:
                    # kxk window around the predicted coordinate in the other image.
                    corr = local_correlation(
                        x, y, local_radius=self.local_corr_radius, flow=flow
                    )
                else:
                    # Correlate with the warped features instead; this differs
                    # from the above, especially for large viewpoint changes.
                    corr = local_correlation(
                        x,
                        x_hat,
                        local_radius=self.local_corr_radius,
                    )
                # Zeroing happens after the correlation was computed from x_hat.
                if self.no_support_fm:
                    parts[1] = torch.zeros_like(x)
                parts.append(corr)
            stacked_input = torch.cat(parts, dim=1)

        out = self.out_conv(self.hidden_blocks(self.block1(stacked_input)))
        return out[:, :-2], out[:, -2:]
class CosKernel(nn.Module):  # similar to softmax kernel
    """Exponentiated cosine-similarity kernel ``exp((cos(x, y) - 1) / T)``.

    Args:
        T: temperature. When ``learn_temperature`` is set it becomes a
            trainable scalar parameter, and the effective temperature is
            ``|T| + 0.01``.
        learn_temperature: make ``T`` a learnable ``nn.Parameter``.
    """

    def __init__(self, T, learn_temperature=False):
        super().__init__()
        self.learn_temperature = learn_temperature
        if self.learn_temperature:
            # nn.Parameter needs a floating dtype: torch.tensor(1) would be int64
            # and raise "Only Tensors of floating point dtype can require gradients".
            self.T = nn.Parameter(torch.tensor(float(T)))
        else:
            self.T = T

    def __call__(self, x, y, eps=1e-6):
        """Return the (b, n, m) kernel matrix between x (b, n, d) and y (b, m, d).

        ``eps`` guards the cosine denominator against zero-norm vectors.
        """
        c = torch.einsum("bnd,bmd->bnm", x, y) / (
            x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
        )
        if self.learn_temperature:
            # Keep the temperature strictly positive.
            T = self.T.abs() + 0.01
        else:
            T = torch.tensor(self.T, device=c.device)
        K = ((c - 1.0) / T).exp()
        return K
class CAB(nn.Module):
    """Channel attention block fusing an (old, new) pair of feature maps.

    The concatenated pair is globally average-pooled. Two 1x1 convs and a
    sigmoid then produce a per-channel gate, which scales the second
    ("new") map before it is added to the first ("old") map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        # Attribute name kept as-is for compatibility with existing code.
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """``x`` is a pair ``(high, low)`` (old, new); returns ``high + gate * low``."""
        high, low = x
        pooled = self.global_pooling(torch.cat([high, low], dim=1))
        gate = self.sigmod(self.conv2(self.relu(self.conv1(pooled))))
        return high + gate * low
class RRB(nn.Module):
    """Refinement residual block.

    A 1x1 projection is followed by a residual branch
    ``conv(k) -> BN -> ReLU -> conv(k)``. The block outputs
    ``ReLU(projection + branch)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=pad
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(
            out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=pad
        )

    def forward(self, x):
        """Apply the block; spatial size is preserved for odd kernel sizes."""
        projected = self.conv1(x)
        branch = self.conv3(self.relu(self.bn(self.conv2(projected))))
        return self.relu(projected + branch)
class DFN(nn.Module):
    """Per-scale embedding decoder.

    At each scale ``key``:

    1. project the query features;
    2. concatenate them with the incoming embeddings;
    3. pass the result through ``rrb_d``;
    4. fuse it with the running context via ``cab``;
    5. refine the context with ``rrb_u``;
    6. predict with the terminal head.

    The last two prediction channels are returned as the coordinate map; the
    remaining leading channels are returned as certainty.

    ``terminal_module`` is either a mapping from scale key (as ``str``) to a
    head, e.g. ``nn.ModuleDict``, whose keys define :meth:`scales`, or a
    single head shared across scales. ``None`` means ``nn.Identity()``; a
    shared head declares no scales.
    """

    def __init__(
        self,
        internal_dim,
        feat_input_modules,
        pred_input_modules,
        rrb_d_dict,
        cab_dict,
        rrb_u_dict,
        use_global_context=False,
        global_dim=None,
        terminal_module=None,
        upsample_mode="bilinear",
        align_corners=False,
    ):
        super().__init__()
        if use_global_context:
            assert (
                global_dim is not None
            ), "Global dim must be provided when using global context"
        self.align_corners = align_corners
        self.internal_dim = internal_dim
        self.feat_input_modules = feat_input_modules
        self.pred_input_modules = pred_input_modules
        self.rrb_d = rrb_d_dict
        self.cab = cab_dict
        self.rrb_u = rrb_u_dict
        self.use_global_context = use_global_context
        if use_global_context:
            # NOTE(review): created but not used by forward() in this class.
            self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
            self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.terminal_module = (
            terminal_module if terminal_module is not None else nn.Identity()
        )
        self.upsample_mode = upsample_mode
        # Previously this unconditionally called .keys(), which crashed with the
        # default nn.Identity() terminal module.
        if self._has_scale_heads():
            self._scales = [int(key) for key in self.terminal_module.keys()]
        else:
            self._scales = []

    def _has_scale_heads(self):
        """Return True when ``terminal_module`` maps scale keys to heads."""
        return hasattr(self.terminal_module, "keys")

    def scales(self):
        """Return a copy of the integer scales this decoder has heads for."""
        return self._scales.copy()

    def forward(self, embeddings, feats, context, key):
        """Decode one scale.

        Returns:
            ``(pred_coord, pred_certainty, context)``: the last two channels of
            the head output, the remaining leading channels, and the updated
            context.
        """
        key = str(key)
        feats = self.feat_input_modules[key](feats)
        embeddings = torch.cat([feats, embeddings], dim=1)
        embeddings = self.rrb_d[key](embeddings)
        context = self.cab[key]([context, embeddings])
        context = self.rrb_u[key](context)
        if self._has_scale_heads():
            head = self.terminal_module[key]
        else:
            head = self.terminal_module
        preds = head(context)
        pred_coord = preds[:, -2:]
        pred_certainty = preds[:, :-2]
        return pred_coord, pred_certainty, context
class GP(nn.Module):
    """Gaussian-process matcher.

    The query features ``x`` are regressed onto a positional embedding of the
    support grid ``y``. The output is the GP posterior mean (``gp_dim``
    channels). Unless ``no_cov`` is set, ``covar_size ** 2`` channels of local
    posterior covariance are appended.

    Args:
        kernel: factory called as ``kernel(T=..., learn_temperature=...)``
            returning a callable ``K(a, b)`` -> (b, n, m).
        T, learn_temperature: forwarded to ``kernel``.
        gp_dim: channels of the positional basis.
        basis: ``"fourier"`` or ``"linear"`` positional basis.
        covar_size: side of the local covariance window (K x K).
        sigma_noise: observation-noise variance added to ``K_yy``'s diagonal.
        no_cov: return only the posterior mean.
        predict_features: add the first ``gp_dim`` channels of ``y`` to the
            positional embedding.
    """

    # Above this many support positions, K_yy is inverted one batch element at
    # a time (avoids https://github.com/pytorch/pytorch/issues/16963 warnings).
    per_sample_inverse_threshold = 2000

    def __init__(
        self,
        kernel,
        T=1,
        learn_temperature=False,
        only_attention=False,
        gp_dim=64,
        basis="fourier",
        covar_size=5,
        only_nearest_neighbour=False,
        sigma_noise=0.1,
        no_cov=False,
        predict_features=False,
    ):
        super().__init__()
        self.K = kernel(T=T, learn_temperature=learn_temperature)
        self.sigma_noise = sigma_noise
        self.covar_size = covar_size
        self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
        self.only_attention = only_attention
        self.only_nearest_neighbour = only_nearest_neighbour
        self.basis = basis
        self.no_cov = no_cov
        self.dim = gp_dim
        self.predict_features = predict_features

    def get_local_cov(self, cov):
        """Gather a K x K window of ``cov`` around each position.

        ``cov`` has shape (b, h, w, h, w). The result is (b, h, w, K**2), where
        entry ``[:, i, j]`` holds ``cov[:, i, j]`` sampled in a zero-padded
        K x K neighbourhood of (i, j).
        """
        K = self.covar_size
        b, h, w = cov.shape[:3]
        hw = h * w
        device = cov.device
        cov = F.pad(cov, 4 * (K // 2,))  # pad v_q
        offsets = torch.arange(-(K // 2), K // 2 + 1, device=device)
        delta = torch.stack(torch.meshgrid(offsets, offsets), dim=-1)
        positions = torch.stack(
            torch.meshgrid(
                torch.arange(K // 2, h + K // 2, device=device),
                torch.arange(K // 2, w + K // 2, device=device),
            ),
            dim=-1,
        )
        neighbours = positions[:, :, None, None, :] + delta[None, :, :]
        points = torch.arange(hw, device=device)[:, None].expand(hw, K**2)
        local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
            :,
            points.flatten(),
            neighbours[..., 0].flatten(),
            neighbours[..., 1].flatten(),
        ].reshape(b, h, w, K**2)
        return local_cov

    def reshape(self, x):
        """Flatten a (b, d, h, w) map to (b, h*w, d) tokens."""
        return x.flatten(2).transpose(1, 2)

    def project_to_basis(self, x):
        """Embed 2-channel coordinates with the configured basis."""
        if self.basis == "fourier":
            return torch.cos(8 * math.pi * self.pos_conv(x))
        elif self.basis == "linear":
            return self.pos_conv(x)
        else:
            raise ValueError(
                "No other bases other than fourier and linear currently supported in public release"
            )

    def get_pos_enc(self, y):
        """Embed the pixel-centre coordinate grid of ``y``; returns (b, gp_dim, h, w)."""
        b, c, h, w = y.shape
        coarse_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
            )
        )
        # Channels ordered (x, y).
        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
            None
        ].expand(b, h, w, 2)
        coarse_coords = coarse_coords.permute(0, 3, 1, 2)
        coarse_embedded_coords = self.project_to_basis(coarse_coords)
        return coarse_embedded_coords

    def forward(self, x, y, **kwargs):
        """Return GP features for query ``x`` given support ``y`` (extra kwargs ignored)."""
        b, c, h1, w1 = x.shape
        b, c, h2, w2 = y.shape
        f = self.get_pos_enc(y)
        if self.predict_features:
            f = f + y[:, : self.dim]  # Stupid way to predict features
        b, d, h2, w2 = f.shape
        x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
        K_yy = self.K(y, y)
        K_xy = self.K(x, y)
        # Shape (1, N, N): broadcast against every batch element.
        sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
        if len(K_yy[0]) > self.per_sample_inverse_threshold:
            # sigma_noise is broadcast rather than sliced: slicing [k:k+1] is
            # empty for k >= 1 and silently reused sample 0's inverse.
            K_yy_inv = torch.cat(
                [torch.linalg.inv(K_yy[k : k + 1] + sigma_noise) for k in range(b)]
            )
        else:
            K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
        mu_x = K_xy.matmul(K_yy_inv.matmul(f))
        mu_x = mu_x.transpose(1, 2).reshape(b, d, h1, w1)
        if self.no_cov:
            return mu_x
        K_xx = self.K(x, x)
        K_yx = K_xy.permute(0, 2, 1)
        cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
        cov_x = cov_x.reshape(b, h1, w1, h1, w1)
        local_cov_x = self.get_local_cov(cov_x).permute(0, 3, 1, 2)
        return torch.cat((mu_x, local_cov_x), dim=1)
class Encoder(nn.Module):
    """ResNet-backed feature pyramid extractor.

    ``resnet`` must expose ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
    ``layer1``..``layer4``. The keys of the returned dict are the nominal
    strides of the standard ResNet layout (1 is the input itself).
    """

    def __init__(self, resnet):
        super().__init__()
        self.resnet = resnet

    def forward(self, x):
        """Return ``{32: layer4, 16: layer3, 8: layer2, 4: layer1, 2: stem, 1: input}``."""
        x0 = x
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x1 = self.resnet.relu(x)
        x = self.resnet.maxpool(x1)
        x2 = self.resnet.layer1(x)
        x3 = self.resnet.layer2(x2)
        x4 = self.resnet.layer3(x3)
        x5 = self.resnet.layer4(x4)
        return {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0}

    def train(self, mode=True):
        """Set training mode but keep every BatchNorm2d in eval mode (frozen stats).

        Returns ``self``, matching the ``nn.Module.train`` contract; the
        previous version returned ``None``.
        """
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
        return self
class Decoder(nn.Module):
    """Coarse-to-fine decoder producing a dense flow and certainty per scale.

    For each scale in ``self.scales`` (string keys, coarsest first):

    * optionally project features with ``proj``;
    * at scales reported by ``embedding_decoder.scales()``, run the GP
      matcher (``gps``) and the embedding decoder to (re)predict flow,
      certainty and a hidden state (``old_stuff``);
    * at scales present in ``conv_refiner``, add a refinement;
    * record the result, then bilinearly upsample flow and certainty to the
      next finer scale.
    """

    def __init__(
        self,
        embedding_decoder,
        gps,
        proj,
        conv_refiner,
        transformers=None,
        detach=False,
        scales="all",
        pos_embeddings=None,
    ):
        super().__init__()
        self.embedding_decoder = embedding_decoder
        self.gps = gps
        self.proj = proj
        self.conv_refiner = conv_refiner
        # When True, flow/certainty are detached after being upsampled to the next scale.
        self.detach = detach
        # NOTE(review): `transformers` and `pos_embeddings` are accepted but not stored or used here.
        if scales == "all":
            self.scales = ["32", "16", "8", "4", "2", "1"]
        else:
            self.scales = scales

    def upsample_preds(self, flow, certainty, query, support):
        """Upsample a channels-last flow and certainty to ``query``'s size and refine them.

        ``flow`` is unpacked as (b, hs, ws, d) and permuted to channels-first.
        The refinement uses ``conv_refiner["1"]``; its displacement is scaled
        by ``1 / (4 * w)`` and ``1 / (4 * h)``. Returns a channels-last flow
        and the refined certainty.
        """
        b, hs, ws, d = flow.shape
        b, c, h, w = query.shape
        flow = flow.permute(0, 3, 1, 2)
        certainty = F.interpolate(
            certainty, size=(h, w), align_corners=False, mode="bilinear"
        )
        flow = F.interpolate(flow, size=(h, w), align_corners=False, mode="bilinear")
        delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
        flow = torch.stack(
            (
                flow[:, 0] + delta_flow[:, 0] / (4 * w),
                flow[:, 1] + delta_flow[:, 1] / (4 * h),
            ),
            dim=1,
        )
        flow = flow.permute(0, 2, 3, 1)
        certainty = certainty + delta_certainty
        return flow, certainty

    def get_placeholder_flow(self, b, h, w, device):
        """Return the identity warp as (b, 2, h, w): pixel-centre (x, y) coordinates in (-1, 1)."""
        coarse_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
            )
        )
        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
            None
        ].expand(b, h, w, 2)
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        return coarse_coords

    def forward(self, f1, f2, upsample=False, dense_flow=None, dense_certainty=None):
        """Decode dense correspondences from two feature pyramids.

        Args:
            f1, f2: dicts mapping integer scale -> feature map. Scale 1 must
                be present (its size is read below).
            upsample: if True, only scales "8", "4", "2", "1" are decoded,
                starting from the given ``dense_flow`` / ``dense_certainty``
                resized to the coarsest of those scales.
            dense_flow, dense_certainty: previous predictions; only used when
                ``upsample`` is True.

        Returns:
            dict mapping integer scale -> {"dense_flow", "dense_certainty"}.
        """
        coarse_scales = self.embedding_decoder.scales()
        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
        sizes = {scale: f1[scale].shape[-2:] for scale in f1}
        h, w = sizes[1]
        b = f1[1].shape[0]
        device = f1[1].device
        coarsest_scale = int(all_scales[0])
        # Hidden state of the embedding decoder, starting at zeros.
        old_stuff = torch.zeros(
            b,
            self.embedding_decoder.internal_dim,
            *sizes[coarsest_scale],
            device=f1[coarsest_scale].device
        )
        dense_corresps = {}
        if not upsample:
            # Start from the identity warp and zero certainty.
            dense_flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
            dense_certainty = 0.0
        else:
            dense_flow = F.interpolate(
                dense_flow,
                size=sizes[coarsest_scale],
                align_corners=False,
                mode="bilinear",
            )
            dense_certainty = F.interpolate(
                dense_certainty,
                size=sizes[coarsest_scale],
                align_corners=False,
                mode="bilinear",
            )
        for new_scale in all_scales:
            ins = int(new_scale)
            f1_s, f2_s = f1[ins], f2[ins]
            if new_scale in self.proj:
                f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
            b, c, hs, ws = f1_s.shape
            if ins in coarse_scales:
                # GP + embedding decoder replace flow/certainty at this scale.
                old_stuff = F.interpolate(
                    old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
                )
                new_stuff = self.gps[new_scale](f1_s, f2_s, dense_flow=dense_flow)
                dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
                    new_stuff, f1_s, old_stuff, new_scale
                )
            if new_scale in self.conv_refiner:
                delta_certainty, displacement = self.conv_refiner[new_scale](
                    f1_s, f2_s, dense_flow
                )
                # Displacement is scaled by ins / (4 * full-resolution size).
                dense_flow = torch.stack(
                    (
                        dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w),
                        dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h),
                    ),
                    dim=1,
                )
                dense_certainty = (
                    dense_certainty + delta_certainty
                )  # predict both certainty and displacement
            dense_corresps[ins] = {
                "dense_flow": dense_flow,
                "dense_certainty": dense_certainty,
            }
            if new_scale != "1":
                # Carry predictions to the next (finer) scale, ins // 2.
                dense_flow = F.interpolate(
                    dense_flow,
                    size=sizes[ins // 2],
                    align_corners=False,
                    mode="bilinear",
                )
                dense_certainty = F.interpolate(
                    dense_certainty,
                    size=sizes[ins // 2],
                    align_corners=False,
                    mode="bilinear",
                )
                if self.detach:
                    dense_flow = dense_flow.detach()
                    dense_certainty = dense_certainty.detach()
        return dense_corresps
class RegressionMatcher(nn.Module):
    """Dense matcher: backbone ``encoder`` + coarse-to-fine ``decoder``.

    Args:
        encoder: maps an image batch to a dict of scale -> features.
        decoder: maps two feature pyramids to dense correspondences.
        h, w: resolution images are resized to in :meth:`match`.
        sample_mode: strategy string for :meth:`sample` ("threshold", "pow",
            "naive", optionally combined with "balanced").
        upsample_preds: in :meth:`match`, re-run the finer scales at
            ``self.upsample_res``.
        symmetric: match in both directions in one pass.
    """

    def __init__(
        self,
        encoder,
        decoder,
        h=384,
        w=512,
        use_contrastive_loss=False,
        alpha=1,
        beta=0,
        sample_mode="threshold",
        upsample_preds=False,
        symmetric=False,
        name=None,
        use_soft_mutual_nearest_neighbours=False,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.w_resized = w
        self.h_resized = h
        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
        self.use_contrastive_loss = use_contrastive_loss
        self.alpha = alpha
        self.beta = beta
        self.sample_mode = sample_mode
        self.upsample_preds = upsample_preds
        self.symmetric = symmetric
        self.name = name
        self.sample_thresh = 0.05
        self.upsample_res = (864, 1152)
        if use_soft_mutual_nearest_neighbours:
            assert symmetric, "MNS requires symmetric inference"
        self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours

    def extract_backbone_features(self, batch, batched=True, upsample=True):
        """Run the encoder on ``batch["query"]`` and ``batch["support"]``.

        With ``batched`` the two are concatenated along dim 0 and encoded in
        one pass, returning a single pyramid. Otherwise a
        ``(query_pyramid, support_pyramid)`` tuple is returned.
        """
        # TODO: only extract stride [1,2,4,8] for upsample = True
        x_q = batch["query"]
        x_s = batch["support"]
        if batched:
            X = torch.cat((x_q, x_s))
            feature_pyramid = self.encoder(X)
        else:
            feature_pyramid = self.encoder(x_q), self.encoder(x_s)
        return feature_pyramid

    def sample(
        self,
        dense_matches,
        dense_certainty,
        num=10000,
    ):
        """Draw up to ``num`` matches, weighted by (transformed) certainty.

        ``dense_matches`` is reshaped to (-1, 4) rows; ``dense_certainty`` is
        flattened and used as multinomial weights (sampling without
        replacement). "balanced" modes oversample 4x and then resample by
        inverse KDE density.
        """
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            dense_certainty = dense_certainty.clone()
            dense_certainty[dense_certainty > upper_thresh] = 1
        elif "pow" in self.sample_mode:
            dense_certainty = dense_certainty ** (1 / 3)
        elif "naive" in self.sample_mode:
            dense_certainty = torch.ones_like(dense_certainty)
        matches, certainty = (
            dense_matches.reshape(-1, 4),
            dense_certainty.reshape(-1),
        )
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(
            certainty,
            num_samples=min(expansion_factor * num, len(certainty)),
            replacement=False,
        )
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
        from ..utils.kde import kde

        density = kde(good_matches, std=0.1)
        p = 1 / (density + 1)
        # Basically should have at least 10 perfect neighbours, or around 100 ok ones
        p[density < 10] = 1e-7
        balanced_samples = torch.multinomial(
            p, num_samples=min(num, len(good_certainty)), replacement=False
        )
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch, batched=True, upsample=False):
        """Predict dense correspondences from query to support.

        Args:
            batch: dict with "query" and "support" image tensors. When
                ``upsample`` is True, an optional "corresps" dict (with
                "dense_flow" / "dense_certainty") seeds the decoder.
            batched: encode query and support in one encoder pass.
            upsample: run the decoder in upsampling mode. ``match`` passes
                this; previously the parameter was missing, so that call
                raised TypeError.

        Returns:
            The decoder output, or ``(output, (f_q_pyramid, f_s_pyramid))``
            when training with a contrastive loss.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, batched=batched, upsample=upsample
        )
        if batched:
            f_q_pyramid = {
                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
            }
            f_s_pyramid = {
                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
            }
        else:
            f_q_pyramid, f_s_pyramid = feature_pyramid
        if upsample:
            dense_corresps = self.decoder(
                f_q_pyramid,
                f_s_pyramid,
                upsample=True,
                **batch.get("corresps", {}),
            )
        else:
            dense_corresps = self.decoder(f_q_pyramid, f_s_pyramid)
        if self.training and self.use_contrastive_loss:
            return dense_corresps, (f_q_pyramid, f_s_pyramid)
        else:
            return dense_corresps

    def forward_symmetric(self, batch, upsample=False, batched=True):
        """Match query->support and support->query in one decoder pass.

        The query-side pyramid is the full batched pyramid; the support side
        swaps its two halves. Requires ``batched=True``, since ``.items()`` is
        called on the pyramid.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, upsample=upsample, batched=batched
        )
        f_q_pyramid = feature_pyramid
        f_s_pyramid = {
            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
            for scale, f_scale in feature_pyramid.items()
        }
        dense_corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **(batch["corresps"] if "corresps" in batch else {})
        )
        return dense_corresps

    def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
        """Convert normalised [-1, 1] match coordinates to pixel coordinates.

        ``matches[..., :2]`` are (x, y) in image A and ``matches[..., 2:]``
        are (x, y) in image B.
        """
        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
        kpts_A = torch.stack(
            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), dim=-1
        )
        kpts_B = torch.stack(
            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), dim=-1
        )
        return kpts_A, kpts_B

    def match(self, im1_path, im2_path, *args, batched=False, device=None):
        """Estimate a dense warp and certainty between two images.

        Args:
            im1_path, im2_path: file paths or PIL images; image tensors of
                shape (b, c, h, w) when ``batched``.
            batched: treat inputs as already-prepared tensor batches.
            device: defaults to CUDA when available, else CPU.

        Returns:
            ``(warp, certainty)``. ``warp[..., :2]`` holds query coordinates
            and ``warp[..., 2:]`` the matched coordinates, both normalised to
            [-1, 1]. Without ``batched``, the batch dim is dropped.
        """
        assert not (
            batched and self.upsample_preds
        ), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
        if isinstance(im1_path, (str, os.PathLike)):
            im1, im2 = Image.open(im1_path), Image.open(im2_path)
        else:  # assume it is a PIL Image
            im1, im2 = im1_path, im2_path
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        symmetric = self.symmetric
        self.train(False)
        with torch.no_grad():
            if not batched:
                b = 1
                # Get images in good format
                ws = self.w_resized
                hs = self.h_resized
                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True
                )
                query, support = test_transform((im1, im2))
                batch = {
                    "query": query[None].to(device),
                    "support": support[None].to(device),
                }
            else:
                b, c, h, w = im1.shape
                b, c, h2, w2 = im2.shape
                assert w == w2 and h == h2, "For batched images we assume same size"
                batch = {"query": im1.to(device), "support": im2.to(device)}
                # Batched inputs are not resized, so the finest-scale warp lives
                # at the input resolution; build the query grid at that size.
                hs, ws = h, w
            finest_scale = 1
            # Run matcher
            if symmetric:
                dense_corresps = self.forward_symmetric(batch, batched=True)
            else:
                dense_corresps = self.forward(batch, batched=True)
            if self.upsample_preds:
                hs, ws = self.upsample_res
            low_res_certainty = F.interpolate(
                dense_corresps[16]["dense_certainty"],
                size=(hs, ws),
                align_corners=False,
                mode="bilinear",
            )
            cert_clamp = 0
            factor = 0.5
            # Keep only the negative (low-certainty) coarse logits, halved.
            low_res_certainty = (
                factor * low_res_certainty * (low_res_certainty < cert_clamp)
            )
            if self.upsample_preds:
                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True
                )
                query, support = test_transform((im1, im2))
                query, support = query[None].to(device), support[None].to(device)
                batch = {
                    "query": query,
                    "support": support,
                    "corresps": dense_corresps[finest_scale],
                }
                if symmetric:
                    dense_corresps = self.forward_symmetric(
                        batch, upsample=True, batched=True
                    )
                else:
                    dense_corresps = self.forward(batch, batched=True, upsample=True)
            query_to_support = dense_corresps[finest_scale]["dense_flow"]
            dense_certainty = dense_corresps[finest_scale]["dense_certainty"]
            # Get certainty interpolation
            dense_certainty = dense_certainty - low_res_certainty
            query_to_support = query_to_support.permute(0, 2, 3, 1)
            # Create im1 meshgrid
            query_coords = torch.meshgrid(
                (
                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                )
            )
            query_coords = torch.stack((query_coords[1], query_coords[0]))
            query_coords = query_coords[None].expand(b, 2, hs, ws)
            dense_certainty = dense_certainty.sigmoid()  # logits -> probs
            query_coords = query_coords.permute(0, 2, 3, 1)
            # Matches landing outside the image get zero certainty.
            if (query_to_support.abs() > 1).any():
                wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
                dense_certainty[wrong[:, None]] = 0
            query_to_support = torch.clamp(query_to_support, -1, 1)
            if symmetric:
                support_coords = query_coords
                qts, stq = query_to_support.chunk(2)
                q_warp = torch.cat((query_coords, qts), dim=-1)
                s_warp = torch.cat((stq, support_coords), dim=-1)
                warp = torch.cat((q_warp, s_warp), dim=2)
                dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:, 0]
            else:
                warp = torch.cat((query_coords, query_to_support), dim=-1)
            if batched:
                return (warp, dense_certainty)
            else:
                return (
                    warp[0],
                    dense_certainty[0],
                )