|
import math |
|
import os |
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from ..utils import get_tuple_transform_ops |
|
from einops import rearrange |
|
from ..utils.local_correlation import local_correlation |
|
|
|
|
|
class ConvRefiner(nn.Module):
    """Convolutional head that refines a coarse flow between two feature maps.

    The network is one input block, ``hidden_blocks`` further blocks of the
    same kind and a final 1x1 convolution.  The last two output channels are
    returned as the displacement; any leading channels are the certainty.
    """

    def __init__(
        self,
        in_dim=6,
        hidden_dim=16,
        out_dim=2,
        dw=False,
        kernel_size=5,
        hidden_blocks=3,
        displacement_emb=None,
        displacement_emb_dim=None,
        local_corr_radius=None,
        corr_in_other=None,
        no_support_fm=False,
    ):
        """Create the convolution stack and the optional displacement embedding.

        Args:
            in_dim: Channels of the concatenated tensor fed to the first block.
            hidden_dim: Channels used by every block after the first.
            out_dim: Channels produced by the final 1x1 convolution.
            dw: If true, the spatial convolutions are grouped per input channel.
            kernel_size: Spatial kernel size of the first convolution of a block.
            hidden_blocks: Number of blocks stacked after the input block.
            displacement_emb: Truthy to embed ``flow - identity grid`` with a
                1x1 convolution and append it to the input.
            displacement_emb_dim: Output channels of that embedding.
            local_corr_radius: Truthy to append a local correlation volume
                (only consulted when the displacement embedding is enabled).
            corr_in_other: Selects which ``local_correlation`` call is used.
            no_support_fm: If true, the warped support features are replaced
                by zeros before concatenation (in the branches that check it).
        """
        super().__init__()
        self.block1 = self.create_block(
            in_dim, hidden_dim, dw=dw, kernel_size=kernel_size
        )
        stacked_blocks = [
            self.create_block(hidden_dim, hidden_dim, dw=dw, kernel_size=kernel_size)
            for _ in range(hidden_blocks)
        ]
        self.hidden_blocks = nn.Sequential(*stacked_blocks)
        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
        self.has_displacement_emb = bool(displacement_emb)
        if self.has_displacement_emb:
            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
        self.local_corr_radius = local_corr_radius
        self.corr_in_other = corr_in_other
        self.no_support_fm = no_support_fm

    def create_block(
        self,
        in_dim,
        out_dim,
        dw=False,
        kernel_size=5,
    ):
        """Return ``conv(k x k) -> BatchNorm -> ReLU -> conv(1 x 1)`` as a Sequential."""
        if dw:
            assert (
                out_dim % in_dim == 0
            ), "outdim must be divisible by indim for depthwise"
            groups = in_dim
        else:
            groups = 1
        spatial_conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            groups=groups,
        )
        batch_norm = nn.BatchNorm2d(out_dim)
        activation = nn.ReLU(inplace=True)
        pointwise_conv = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
        return nn.Sequential(spatial_conv, batch_norm, activation, pointwise_conv)

    def forward(self, x, y, flow):
        """Predict a certainty map and a displacement from ``x``, ``y`` and ``flow``.

        Args:
            x: 4-D feature map; its shape is unpacked as ``(b, c, hs, ws)``.
            y: Feature map that is sampled with ``F.grid_sample`` at ``flow``.
            flow: Sampling grid with the two coordinates in dimension 1 (it is
                permuted to channels-last before sampling).

        Returns:
            Tuple ``(certainty, displacement)``: all but the last two channels,
            and the last two channels, of the network output.
        """
        device = x.device
        b, c, hs, ws = x.shape
        # The warp itself is not differentiated through.
        with torch.no_grad():
            warped = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False)

        if not self.has_displacement_emb:
            if self.no_support_fm:
                warped = torch.zeros_like(x)
            stacked = torch.cat((x, warped), dim=1)
        else:
            # Pixel-centre coordinates in [-1, 1]; stacked so x comes first.
            rows = torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device)
            cols = torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device)
            grid = torch.meshgrid(rows, cols)
            identity = torch.stack((grid[1], grid[0]))[None].expand(b, 2, hs, ws)
            emb_in_displacement = self.disp_emb(flow - identity)
            if self.local_corr_radius:
                if self.corr_in_other:
                    # Correlate against y around the predicted coordinate.
                    local_corr = local_correlation(
                        x, y, local_radius=self.local_corr_radius, flow=flow
                    )
                else:
                    # Correlate against the already warped features.
                    local_corr = local_correlation(
                        x,
                        warped,
                        local_radius=self.local_corr_radius,
                    )
                if self.no_support_fm:
                    warped = torch.zeros_like(x)
                stacked = torch.cat(
                    (x, warped, emb_in_displacement, local_corr), dim=1
                )
            else:
                stacked = torch.cat((x, warped, emb_in_displacement), dim=1)

        out = self.out_conv(self.hidden_blocks(self.block1(stacked)))
        return out[:, :-2], out[:, -2:]
|
|
|
|
|
class CosKernel(nn.Module):
    """Exponentiated cosine-similarity kernel ``exp((cos(x, y) - 1) / T)``."""

    def __init__(self, T, learn_temperature=False):
        """Store the temperature, as a parameter when it should be learned.

        Args:
            T: Temperature dividing ``cos - 1`` before exponentiation.
            learn_temperature: If true, ``T`` is registered as a parameter.
        """
        super().__init__()
        self.learn_temperature = learn_temperature
        if self.learn_temperature:
            # Cast to float: an integer tensor (e.g. T=1) cannot require
            # gradients and would make nn.Parameter raise.
            self.T = nn.Parameter(torch.tensor(float(T)))
        else:
            self.T = T

    def forward(self, x, y, eps=1e-6):
        """Return the kernel matrix between the rows of ``x`` and ``y``.

        Implemented as ``forward`` (not ``__call__``) so that the regular
        ``nn.Module`` call machinery, including hooks, stays active; calling
        the module as ``kernel(x, y)`` works exactly as before.

        Args:
            x: Tensor indexed as ``(b, n, d)`` by the einsum below.
            y: Tensor indexed as ``(b, m, d)``.
            eps: Added to the product of norms to avoid division by zero.

        Returns:
            Tensor of shape ``(b, n, m)``.
        """
        c = torch.einsum("bnd,bmd->bnm", x, y) / (
            x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
        )
        if self.learn_temperature:
            # Keep the learned temperature strictly positive.
            T = self.T.abs() + 0.01
        else:
            T = torch.tensor(self.T, device=c.device)
        return ((c - 1.0) / T).exp()
|
|
|
|
|
class CAB(nn.Module):
    """Channel gate: scales the second input by weights pooled from both inputs."""

    def __init__(self, in_channels, out_channels):
        """Build the pooling, the two 1x1 convolutions and the activations.

        Args:
            in_channels: Channels of the concatenation of both inputs.
            out_channels: Channels of the gate produced by the convolutions.
        """
        super().__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        # Attribute name kept as-is (existing spelling) for compatibility.
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        """Return ``first + gate * second`` for ``x = (first, second)``.

        The gate is computed from the channel-wise concatenation of both
        inputs: global average pooling, conv, ReLU, conv, sigmoid.
        """
        first, second = x
        pooled = self.global_pooling(torch.cat([first, second], dim=1))
        gate = self.sigmod(self.conv2(self.relu(self.conv1(pooled))))
        gated = gate * second
        return gated + first
|
|
|
|
|
class RRB(nn.Module):
    """Residual block: 1x1 projection plus a conv-BN-ReLU-conv residual branch."""

    def __init__(self, in_channels, out_channels, kernel_size=3):
        """Create the projection and the two same-padded spatial convolutions.

        Args:
            in_channels: Channels of the input.
            out_channels: Channels of the output and of every inner layer.
            kernel_size: Kernel size of the two spatial convolutions.
        """
        super().__init__()
        same_padding = kernel_size // 2
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=same_padding,
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=same_padding,
        )

    def forward(self, x):
        """Return ``relu(proj + conv3(relu(bn(conv2(proj)))))`` with ``proj = conv1(x)``."""
        projected = self.conv1(x)
        branch = self.conv3(self.relu(self.bn(self.conv2(projected))))
        return self.relu(projected + branch)
|
|
|
|
|
class DFN(nn.Module):
    """Per-scale embedding decoder built from caller-supplied module mappings.

    Each of ``feat_input_modules``, ``rrb_d_dict``, ``cab_dict`` and
    ``rrb_u_dict`` is indexed with ``str(key)`` in ``forward``.
    """

    def __init__(
        self,
        internal_dim,
        feat_input_modules,
        pred_input_modules,
        rrb_d_dict,
        cab_dict,
        rrb_u_dict,
        use_global_context=False,
        global_dim=None,
        terminal_module=None,
        upsample_mode="bilinear",
        align_corners=False,
    ):
        """Store the sub-modules and derive the list of supported scales.

        Args:
            internal_dim: Channel width of the internal context.
            feat_input_modules: Mapping ``str(scale) -> module`` applied to the features.
            pred_input_modules: Stored as given; not used by ``forward``.
            rrb_d_dict: Mapping ``str(scale) -> module`` applied to the embeddings.
            cab_dict: Mapping ``str(scale) -> module`` taking ``[context, embeddings]``.
            rrb_u_dict: Mapping ``str(scale) -> module`` applied to the fused context.
            use_global_context: If true, extra global-context layers are created.
            global_dim: Input channels of the global-context convolution;
                required when ``use_global_context`` is true.
            terminal_module: Mapping ``str(scale) -> module`` producing the
                predictions, or ``None`` for a pass-through.
            upsample_mode: Stored as given.
            align_corners: Stored as given.
        """
        super().__init__()
        if use_global_context:
            assert (
                global_dim is not None
            ), "Global dim must be provided when using global context"
        self.align_corners = align_corners
        self.internal_dim = internal_dim
        self.feat_input_modules = feat_input_modules
        self.pred_input_modules = pred_input_modules
        self.rrb_d = rrb_d_dict
        self.cab = cab_dict
        self.rrb_u = rrb_u_dict
        self.use_global_context = use_global_context
        if use_global_context:
            self.global_to_internal = nn.Conv2d(global_dim, self.internal_dim, 1, 1, 0)
            self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.terminal_module = (
            terminal_module if terminal_module is not None else nn.Identity()
        )
        self.upsample_mode = upsample_mode
        # The nn.Identity fallback is a single module, not a mapping, and has
        # no keys(); previously this made the default argument crash here.
        if hasattr(self.terminal_module, "keys"):
            self._scales = [int(key) for key in self.terminal_module.keys()]
        else:
            self._scales = []

    def scales(self):
        """Return a copy of the scales for which a terminal module is registered."""
        return self._scales.copy()

    def forward(self, embeddings, feats, context, key):
        """Fuse features, embeddings and context at scale ``key``.

        Args:
            embeddings: Tensor concatenated after the processed ``feats`` on dim 1.
            feats: Input to ``feat_input_modules[str(key)]``.
            context: Running context passed to the ``cab`` module.
            key: Scale identifier; converted with ``str`` for the lookups.

        Returns:
            Tuple ``(pred_coord, pred_certainty, context)`` where the first two
            are the last two / the remaining channels of the predictions.
        """
        scale_key = str(key)
        feats = self.feat_input_modules[scale_key](feats)
        embeddings = torch.cat([feats, embeddings], dim=1)
        embeddings = self.rrb_d[scale_key](embeddings)
        context = self.cab[scale_key]([context, embeddings])
        context = self.rrb_u[scale_key](context)
        terminal = self.terminal_module
        if hasattr(terminal, "keys"):
            # Mapping of per-scale heads; otherwise a single shared module.
            terminal = terminal[scale_key]
        preds = terminal(context)
        pred_coord = preds[:, -2:]
        pred_certainty = preds[:, :-2]
        return pred_coord, pred_certainty, context
|
|
|
|
|
class GP(nn.Module):
    """Kernel regression of embedded coordinates from one feature map to another.

    ``forward`` computes ``K_xy (K_yy + sigma I)^-1 f`` and, unless ``no_cov``
    is set, a local window of ``K_xx - K_xy (K_yy + sigma I)^-1 K_yx``.
    """

    def __init__(
        self,
        kernel,
        T=1,
        learn_temperature=False,
        only_attention=False,
        gp_dim=64,
        basis="fourier",
        covar_size=5,
        only_nearest_neighbour=False,
        sigma_noise=0.1,
        no_cov=False,
        predict_features=False,
    ):
        """Instantiate the kernel and the coordinate embedding.

        Args:
            kernel: Callable invoked as ``kernel(T=..., learn_temperature=...)``.
            T: Temperature forwarded to ``kernel``.
            learn_temperature: Forwarded to ``kernel``.
            only_attention: Stored as given; not used in this class.
            gp_dim: Output channels of the 1x1 coordinate embedding.
            basis: ``"fourier"`` or ``"linear"``; see ``project_to_basis``.
            covar_size: Side length of the local covariance window.
            only_nearest_neighbour: Stored as given; not used in this class.
            sigma_noise: Scale of the identity added to ``K_yy`` before inversion.
            no_cov: If true, ``forward`` returns only the regressed mean.
            predict_features: If true, the first ``gp_dim`` channels of ``y``
                are added to the embedded coordinates.
        """
        super().__init__()
        self.K = kernel(T=T, learn_temperature=learn_temperature)
        self.sigma_noise = sigma_noise
        self.covar_size = covar_size
        self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
        self.only_attention = only_attention
        self.only_nearest_neighbour = only_nearest_neighbour
        self.basis = basis
        self.no_cov = no_cov
        self.dim = gp_dim
        self.predict_features = predict_features

    def get_local_cov(self, cov):
        """Gather, for each position, the covariance with its K x K neighbourhood.

        Args:
            cov: Tensor of shape ``(b, h, w, h, w)``.

        Returns:
            Tensor of shape ``(b, h, w, K**2)`` with ``K = covar_size``;
            neighbours outside the map read the zero padding.
        """
        K = self.covar_size
        b, h, w, h, w = cov.shape
        hw = h * w
        # Zero-pad the last two dims so border neighbourhoods stay in range.
        cov = F.pad(cov, 4 * (K // 2,))
        delta = torch.stack(
            torch.meshgrid(
                torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
            ),
            dim=-1,
        )
        # Positions are shifted by the padding offset.
        positions = torch.stack(
            torch.meshgrid(
                torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
            ),
            dim=-1,
        )
        neighbours = positions[:, :, None, None, :] + delta[None, :, :]
        points = torch.arange(hw)[:, None].expand(hw, K**2)
        local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
            :,
            points.flatten(),
            neighbours[..., 0].flatten(),
            neighbours[..., 1].flatten(),
        ].reshape(b, h, w, K**2)
        return local_cov

    def reshape(self, x):
        """Flatten the spatial dims: ``(b, d, h, w) -> (b, h*w, d)``."""
        return rearrange(x, "b d h w -> b (h w) d")

    def project_to_basis(self, x):
        """Embed coordinates with ``pos_conv``, optionally through a cosine.

        Raises:
            ValueError: If ``basis`` is neither ``"fourier"`` nor ``"linear"``.
        """
        if self.basis == "fourier":
            return torch.cos(8 * math.pi * self.pos_conv(x))
        elif self.basis == "linear":
            return self.pos_conv(x)
        else:
            raise ValueError(
                "No other bases other than fourier and linear currently supported in public release"
            )

    def get_pos_enc(self, y):
        """Return the embedded pixel-centre coordinate grid matching ``y``'s size."""
        b, c, h, w = y.shape
        coarse_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
            )
        )
        # Stack as (x, y) and broadcast over the batch.
        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
            None
        ].expand(b, h, w, 2)
        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
        coarse_embedded_coords = self.project_to_basis(coarse_coords)
        return coarse_embedded_coords

    def _regularised_inverse(self, K_yy, max_batched_size=2000):
        """Return ``(K_yy + sigma_noise * I)^-1`` for every batch element.

        Matrices larger than ``max_batched_size`` are inverted one sample at a
        time.  The noise term has batch size 1 and is broadcast; slicing it
        per sample (as was done before) yields an empty tensor for every
        sample after the first and silently drops their inverses.
        """
        size = K_yy.shape[-1]
        noise = self.sigma_noise * torch.eye(size, device=K_yy.device)[None, :, :]
        if size > max_batched_size:
            return torch.cat(
                [
                    torch.linalg.inv(K_yy[k : k + 1] + noise)
                    for k in range(K_yy.shape[0])
                ]
            )
        return torch.linalg.inv(K_yy + noise)

    def forward(self, x, y, **kwargs):
        """Regress embedded coordinates of ``y`` onto the positions of ``x``.

        Args:
            x: Feature map ``(b, c, h1, w1)`` at which predictions are made.
            y: Feature map ``(b, c, h2, w2)`` providing the targets.
            **kwargs: Accepted and ignored.

        Returns:
            ``(b, gp_dim, h1, w1)`` when ``no_cov`` is set, otherwise the mean
            concatenated with ``covar_size**2`` local covariance channels.
        """
        b, c, h1, w1 = x.shape
        b, c, h2, w2 = y.shape
        f = self.get_pos_enc(y)
        if self.predict_features:
            f = f + y[:, : self.dim]
        b, d, h2, w2 = f.shape

        x, y, f = self.reshape(x), self.reshape(y), self.reshape(f)
        K_xx = self.K(x, x)
        K_yy = self.K(y, y)
        K_xy = self.K(x, y)
        K_yx = K_xy.permute(0, 2, 1)
        K_yy_inv = self._regularised_inverse(K_yy)

        mu_x = K_xy.matmul(K_yy_inv.matmul(f))
        mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
        if self.no_cov:
            return mu_x
        cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
        cov_x = rearrange(
            cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
        )
        local_cov_x = self.get_local_cov(cov_x)
        local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
        return torch.cat((mu_x, local_cov_x), dim=1)
|
|
|
|
|
class Encoder(nn.Module):
    """Wraps a ResNet-style backbone and exposes its intermediate feature maps."""

    def __init__(self, resnet):
        """Store the backbone.

        Args:
            resnet: Module providing ``conv1``, ``bn1``, ``relu``, ``maxpool``
                and ``layer1`` .. ``layer4``.
        """
        super().__init__()
        self.resnet = resnet

    def forward(self, x):
        """Run the backbone and return its feature pyramid.

        Returns:
            Dict keyed by 32, 16, 8, 4, 2 and 1 holding the outputs of
            ``layer4``, ``layer3``, ``layer2``, ``layer1``, the stem
            activation and the unmodified input respectively.
        """
        x0 = x
        x1 = self.resnet.relu(self.resnet.bn1(self.resnet.conv1(x)))
        x2 = self.resnet.layer1(self.resnet.maxpool(x1))
        x3 = self.resnet.layer2(x2)
        x4 = self.resnet.layer3(x3)
        x5 = self.resnet.layer4(x4)
        return {32: x5, 16: x4, 8: x3, 4: x2, 2: x1, 1: x0}

    def train(self, mode=True):
        """Set the training mode while keeping every BatchNorm2d in eval mode.

        Returns:
            ``self``, matching ``nn.Module.train``.  This also makes
            ``eval()`` (which delegates to ``train(False)``) return the module
            instead of ``None``.
        """
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Batch-norm statistics stay frozen regardless of mode.
                m.eval()
        return self
|
|
|
|
|
class Decoder(nn.Module):
    """Coarse-to-fine decoder that produces a dense flow and certainty per scale."""

    def __init__(
        self,
        embedding_decoder,
        gps,
        proj,
        conv_refiner,
        transformers=None,
        detach=False,
        scales="all",
        pos_embeddings=None,
    ):
        """Store the sub-modules and the list of scales to decode.

        Args:
            embedding_decoder: Object with ``scales()``, ``internal_dim`` and a
                call returning ``(flow, certainty, context)``.
            gps: Mapping ``scale name -> module`` used at the coarse scales.
            proj: Mapping ``scale name -> projection`` applied to both features.
            conv_refiner: Mapping ``scale name -> refiner``.
            transformers: Accepted but unused.
            detach: If true, flow and certainty are detached between scales.
            scales: ``"all"`` for ``["32", "16", "8", "4", "2", "1"]`` or an
                explicit sequence of scale names.
            pos_embeddings: Accepted but unused.
        """
        super().__init__()
        self.embedding_decoder = embedding_decoder
        self.gps = gps
        self.proj = proj
        self.conv_refiner = conv_refiner
        self.detach = detach
        self.scales = ["32", "16", "8", "4", "2", "1"] if scales == "all" else scales

    @staticmethod
    def _resize(tensor, size):
        """Bilinearly resample ``tensor`` to ``size`` (``align_corners=False``)."""
        return F.interpolate(tensor, size=size, align_corners=False, mode="bilinear")

    def upsample_preds(self, flow, certainty, query, support):
        """Resample predictions to ``query``'s resolution and refine them once.

        Args:
            flow: Channels-last flow ``(b, hs, ws, d)``.
            certainty: Certainty map to resample.
            query: Feature map ``(b, c, h, w)`` giving the target size.
            support: Second feature map passed to the refiner.

        Returns:
            Tuple ``(flow, certainty)`` with the flow again channels-last.
        """
        _b, _hs, _ws, _d = flow.shape  # also requires a 4-D flow
        _, _, h, w = query.shape
        target = (h, w)
        flow = flow.permute(0, 3, 1, 2)
        certainty = self._resize(certainty, target)
        flow = self._resize(flow, target)
        delta_certainty, delta_flow = self.conv_refiner["1"](query, support, flow)
        flow_x = flow[:, 0] + delta_flow[:, 0] / (4 * w)
        flow_y = flow[:, 1] + delta_flow[:, 1] / (4 * h)
        flow = torch.stack((flow_x, flow_y), dim=1).permute(0, 2, 3, 1)
        return flow, certainty + delta_certainty

    def get_placeholder_flow(self, b, h, w, device):
        """Return the identity grid of pixel centres in [-1, 1] as ``(b, 2, h, w)``."""
        rows = torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device)
        cols = torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device)
        grid = torch.meshgrid((rows, cols))
        # (x, y) ordering in the last dim, broadcast over the batch.
        coords = torch.stack((grid[1], grid[0]), dim=-1)[None].expand(b, h, w, 2)
        return coords.permute(0, 3, 1, 2)

    def forward(self, f1, f2, upsample=False, dense_flow=None, dense_certainty=None):
        """Decode dense correspondences from two feature pyramids.

        Args:
            f1: Dict ``int scale -> feature map`` for the first image; must
                contain scale 1.
            f2: Matching dict for the second image.
            upsample: If true, only scales 8, 4, 2 and 1 are processed and the
                given ``dense_flow`` / ``dense_certainty`` are the starting point.
            dense_flow: Initial flow, used only when ``upsample`` is true.
            dense_certainty: Initial certainty, used only when ``upsample`` is true.

        Returns:
            Dict ``int scale -> {"dense_flow", "dense_certainty"}``.
        """
        coarse_scales = self.embedding_decoder.scales()
        all_scales = ["8", "4", "2", "1"] if upsample else self.scales
        sizes = {scale: f1[scale].shape[-2:] for scale in f1}
        h, w = sizes[1]
        batch_size = f1[1].shape[0]
        device = f1[1].device
        coarsest_scale = int(all_scales[0])
        coarsest_size = sizes[coarsest_scale]
        old_stuff = torch.zeros(
            batch_size,
            self.embedding_decoder.internal_dim,
            *coarsest_size,
            device=f1[coarsest_scale].device
        )
        dense_corresps = {}
        if upsample:
            dense_flow = self._resize(dense_flow, coarsest_size)
            dense_certainty = self._resize(dense_certainty, coarsest_size)
        else:
            dense_flow = self.get_placeholder_flow(batch_size, *coarsest_size, device)
            dense_certainty = 0.0

        for scale_name in all_scales:
            ins = int(scale_name)
            feat_a, feat_b = f1[ins], f2[ins]
            if scale_name in self.proj:
                feat_a, feat_b = (
                    self.proj[scale_name](feat_a),
                    self.proj[scale_name](feat_b),
                )
            _b, _c, _hs, _ws = feat_a.shape  # requires 4-D features

            if ins in coarse_scales:
                old_stuff = F.interpolate(
                    old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
                )
                new_stuff = self.gps[scale_name](feat_a, feat_b, dense_flow=dense_flow)
                dense_flow, dense_certainty, old_stuff = self.embedding_decoder(
                    new_stuff, feat_a, old_stuff, scale_name
                )

            if scale_name in self.conv_refiner:
                delta_certainty, displacement = self.conv_refiner[scale_name](
                    feat_a, feat_b, dense_flow
                )
                # Displacement is scaled by the stride and the full-res size.
                refined_x = dense_flow[:, 0] + ins * displacement[:, 0] / (4 * w)
                refined_y = dense_flow[:, 1] + ins * displacement[:, 1] / (4 * h)
                dense_flow = torch.stack((refined_x, refined_y), dim=1)
                dense_certainty = dense_certainty + delta_certainty

            dense_corresps[ins] = {
                "dense_flow": dense_flow,
                "dense_certainty": dense_certainty,
            }

            if scale_name != "1":
                next_size = sizes[ins // 2]
                dense_flow = self._resize(dense_flow, next_size)
                dense_certainty = self._resize(dense_certainty, next_size)
                if self.detach:
                    dense_flow = dense_flow.detach()
                    dense_certainty = dense_certainty.detach()
        return dense_corresps
|
|
|
|
|
class RegressionMatcher(nn.Module):
    """Encoder/decoder matcher producing a dense warp and certainty for an image pair."""

    def __init__(
        self,
        encoder,
        decoder,
        h=384,
        w=512,
        use_contrastive_loss=False,
        alpha=1,
        beta=0,
        sample_mode="threshold",
        upsample_preds=False,
        symmetric=False,
        name=None,
        use_soft_mutual_nearest_neighbours=False,
    ):
        """Store the sub-modules and inference settings.

        Args:
            encoder: Module mapping an image batch to a dict of feature maps.
            decoder: Module mapping two feature pyramids to dense correspondences.
            h: Height images are resized to in ``match``.
            w: Width images are resized to in ``match``.
            use_contrastive_loss: If true (and training), ``forward`` also
                returns the two feature pyramids.
            alpha: Stored as given.
            beta: Stored as given.
            sample_mode: String inspected by ``sample`` for the substrings
                ``"threshold"``, ``"pow"``, ``"naive"`` and ``"balanced"``.
            upsample_preds: If true, ``match`` runs a second pass at
                ``upsample_res``.
            symmetric: If true, ``match`` uses ``forward_symmetric``.
            name: Stored as given.
            use_soft_mutual_nearest_neighbours: Requires ``symmetric``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.w_resized = w
        self.h_resized = h
        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
        self.use_contrastive_loss = use_contrastive_loss
        self.alpha = alpha
        self.beta = beta
        self.sample_mode = sample_mode
        self.upsample_preds = upsample_preds
        self.symmetric = symmetric
        self.name = name
        self.sample_thresh = 0.05
        self.upsample_res = (864, 1152)
        if use_soft_mutual_nearest_neighbours:
            assert symmetric, "MNS requires symmetric inference"
        self.use_soft_mutual_nearest_neighbours = use_soft_mutual_nearest_neighbours

    def extract_backbone_features(self, batch, batched=True, upsample=True):
        """Encode ``batch["query"]`` and ``batch["support"]``.

        Returns:
            With ``batched`` true, one pyramid computed on the concatenation of
            both images along the batch dim; otherwise a tuple of two pyramids.
            ``upsample`` is accepted but not used.
        """
        x_q = batch["query"]
        x_s = batch["support"]
        if batched:
            X = torch.cat((x_q, x_s))
            feature_pyramid = self.encoder(X)
        else:
            feature_pyramid = self.encoder(x_q), self.encoder(x_s)
        return feature_pyramid

    def sample(
        self,
        dense_matches,
        dense_certainty,
        num=10000,
    ):
        """Draw matches without replacement, weighted by (re-mapped) certainty.

        Args:
            dense_matches: Tensor reshaped to ``(-1, 4)``.
            dense_certainty: Tensor reshaped to ``(-1,)``; used as sampling weights.
            num: Maximum number of matches to return.

        Returns:
            Tuple ``(matches, certainty)`` of the sampled entries.
        """
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            dense_certainty = dense_certainty.clone()
            # Everything above the threshold is treated as equally certain.
            dense_certainty[dense_certainty > upper_thresh] = 1
        elif "pow" in self.sample_mode:
            dense_certainty = dense_certainty ** (1 / 3)
        elif "naive" in self.sample_mode:
            dense_certainty = torch.ones_like(dense_certainty)
        matches, certainty = (
            dense_matches.reshape(-1, 4),
            dense_certainty.reshape(-1),
        )
        # Balanced mode oversamples first, then thins by density below.
        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(
            certainty,
            num_samples=min(expansion_factor * num, len(certainty)),
            replacement=False,
        )
        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty

        from ..utils.kde import kde

        density = kde(good_matches, std=0.1)
        p = 1 / (density + 1)
        # Samples in sparse regions get a near-zero selection weight.
        p[density < 10] = 1e-7
        balanced_samples = torch.multinomial(
            p, num_samples=min(num, len(good_certainty)), replacement=False
        )
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch, batched=True, upsample=False):
        """Compute dense correspondences from query to support.

        Args:
            batch: Dict with ``"query"`` and ``"support"``; when ``upsample`` is
                true an optional ``"corresps"`` dict is forwarded to the decoder
                as keyword arguments.
            batched: Encode both images in a single encoder call.
            upsample: Forwarded to the decoder.  ``match`` passes this flag,
                which previously raised ``TypeError`` because the parameter
                did not exist.

        Returns:
            The decoder output, or ``(output, (f_q_pyramid, f_s_pyramid))``
            when training with ``use_contrastive_loss``.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, batched=batched, upsample=upsample
        )
        if batched:
            f_q_pyramid = {
                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
            }
            f_s_pyramid = {
                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
            }
        else:
            f_q_pyramid, f_s_pyramid = feature_pyramid
        # Coarse correspondences are only meaningful for the upsampling pass.
        coarse = batch["corresps"] if upsample and "corresps" in batch else {}
        dense_corresps = self.decoder(
            f_q_pyramid, f_s_pyramid, upsample=upsample, **coarse
        )
        if self.training and self.use_contrastive_loss:
            return dense_corresps, (f_q_pyramid, f_s_pyramid)
        else:
            return dense_corresps

    def forward_symmetric(self, batch, upsample=False, batched=True):
        """Compute correspondences in both directions in one decoder call.

        The second pyramid is the first with its two batch halves swapped.
        """
        feature_pyramid = self.extract_backbone_features(
            batch, upsample=upsample, batched=batched
        )
        f_q_pyramid = feature_pyramid
        f_s_pyramid = {
            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]))
            for scale, f_scale in feature_pyramid.items()
        }
        dense_corresps = self.decoder(
            f_q_pyramid,
            f_s_pyramid,
            upsample=upsample,
            **(batch["corresps"] if "corresps" in batch else {})
        )
        return dense_corresps

    def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
        """Map matches from [-1, 1] coordinates to pixel coordinates.

        Args:
            matches: Tensor whose last dim is ``(x_A, y_A, x_B, y_B)``.
            H_A, W_A: Height and width of the first image.
            H_B, W_B: Height and width of the second image.

        Returns:
            Tuple ``(kpts_A, kpts_B)`` with last dim ``(x, y)`` in pixels.
        """
        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
        kpts_A = torch.stack(
            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), dim=-1
        )
        kpts_B = torch.stack(
            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), dim=-1
        )
        return kpts_A, kpts_B

    def match(self, im1_path, im2_path, *args, batched=False, device=None):
        """Match two images and return a dense warp with its certainty.

        Args:
            im1_path: Path to the first image, an already loaded image, or
                (when ``batched``) a tensor ``(b, c, h, w)``.
            im2_path: Same for the second image.
            *args: Accepted and ignored.
            batched: Inputs are tensors rather than single images.
            device: Target device; defaults to CUDA when available, else CPU.

        Returns:
            Tuple ``(warp, certainty)``; the leading batch dim is dropped when
            ``batched`` is false.
        """
        assert not (
            batched and self.upsample_preds
        ), "Cannot upsample preds if in batchmode (as we don't have access to high res images). You can turn off upsample_preds by model.upsample_preds = False "
        if isinstance(im1_path, (str, os.PathLike)):
            im1, im2 = Image.open(im1_path), Image.open(im2_path)
        else:
            # Anything that is not a path is used as given.
            im1, im2 = im1_path, im2_path
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        symmetric = self.symmetric
        self.train(False)
        with torch.no_grad():
            if not batched:
                b = 1
                w, h = im1.size
                w2, h2 = im2.size

                ws = self.w_resized
                hs = self.h_resized

                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True
                )
                query, support = test_transform((im1, im2))
                batch = {
                    "query": query[None].to(device),
                    "support": support[None].to(device),
                }
            else:
                b, c, h, w = im1.shape
                b, c, h2, w2 = im2.shape
                assert w == w2 and h == h2, "For batched images we assume same size"
                batch = {"query": im1.to(device), "support": im2.to(device)}
                hs, ws = self.h_resized, self.w_resized
            finest_scale = 1

            # First pass at the base resolution.
            if symmetric:
                dense_corresps = self.forward_symmetric(batch, batched=True)
            else:
                dense_corresps = self.forward(batch, batched=True)

            if self.upsample_preds:
                hs, ws = self.upsample_res
            low_res_certainty = F.interpolate(
                dense_corresps[16]["dense_certainty"],
                size=(hs, ws),
                align_corners=False,
                mode="bilinear",
            )
            cert_clamp = 0
            factor = 0.5
            # Only the part of the scale-16 certainty below cert_clamp is kept.
            low_res_certainty = (
                factor * low_res_certainty * (low_res_certainty < cert_clamp)
            )

            if self.upsample_preds:
                # Second pass at upsample_res, seeded with the first result.
                test_transform = get_tuple_transform_ops(
                    resize=(hs, ws), normalize=True
                )
                query, support = test_transform((im1, im2))
                query, support = query[None].to(device), support[None].to(device)
                batch = {
                    "query": query,
                    "support": support,
                    "corresps": dense_corresps[finest_scale],
                }
                if symmetric:
                    dense_corresps = self.forward_symmetric(
                        batch, upsample=True, batched=True
                    )
                else:
                    dense_corresps = self.forward(batch, batched=True, upsample=True)
            query_to_support = dense_corresps[finest_scale]["dense_flow"]
            dense_certainty = dense_corresps[finest_scale]["dense_certainty"]

            dense_certainty = dense_certainty - low_res_certainty
            query_to_support = query_to_support.permute(0, 2, 3, 1)

            # Identity grid of pixel centres for the first image.
            query_coords = torch.meshgrid(
                (
                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                )
            )
            query_coords = torch.stack((query_coords[1], query_coords[0]))
            query_coords = query_coords[None].expand(b, 2, hs, ws)
            dense_certainty = dense_certainty.sigmoid()
            query_coords = query_coords.permute(0, 2, 3, 1)
            if (query_to_support.abs() > 1).any():
                # Matches pointing outside [-1, 1] get zero certainty.
                wrong = (query_to_support.abs() > 1).sum(dim=-1) > 0
                dense_certainty[wrong[:, None]] = 0

            query_to_support = torch.clamp(query_to_support, -1, 1)
            if symmetric:
                support_coords = query_coords
                qts, stq = query_to_support.chunk(2)
                q_warp = torch.cat((query_coords, qts), dim=-1)
                s_warp = torch.cat((stq, support_coords), dim=-1)
                # Both directions are placed side by side along the width.
                warp = torch.cat((q_warp, s_warp), dim=2)
                dense_certainty = torch.cat(dense_certainty.chunk(2), dim=3)[:, 0]
            else:
                warp = torch.cat((query_coords, query_to_support), dim=-1)
            if batched:
                return (warp, dense_certainty)
            else:
                return (
                    warp[0],
                    dense_certainty[0],
                )
|
|