|
import torch |
|
|
|
from torch import nn |
|
from ..dkm import * |
|
from ..encoders import * |
|
|
|
|
|
def DKMv3(
    weights,
    h,
    w,
    symmetric=True,
    sample_mode="threshold_balanced",
    device=None,
    **kwargs
):
    """Assemble the DKMv3 matcher and load ``weights`` into it.

    The function wires together a ``ResNet50`` encoder, a ``Decoder`` made of
    a ``DFN`` coordinate decoder, two ``GP`` modules (scales "32" and "16"),
    1x1 projection convolutions and a ``ConvRefiner`` per scale
    ("16", "8", "4", "2", "1"), wraps them in a ``RegressionMatcher``, moves
    the result to ``device`` and finally calls ``load_state_dict(weights)``.

    Args:
        weights: Object handed unchanged to ``matcher.load_state_dict``
            (presumably a PyTorch state dict for the full matcher).
        h: Forwarded to ``RegressionMatcher`` as ``h`` (presumably the
            working image height -- confirm against ``RegressionMatcher``).
        w: Forwarded to ``RegressionMatcher`` as ``w`` (presumably the
            working image width -- confirm against ``RegressionMatcher``).
        symmetric: Forwarded to ``RegressionMatcher`` as ``symmetric``.
        sample_mode: Forwarded to ``RegressionMatcher`` as ``sample_mode``.
        device: Target device for the matcher. When ``None``, CUDA is used if
            ``torch.cuda.is_available()`` and CPU otherwise.
        **kwargs: Extra keyword arguments forwarded to ``RegressionMatcher``.

    Returns:
        The constructed ``RegressionMatcher`` on ``device`` after
        ``load_state_dict(weights)`` has been called on it.

    Raises:
        Whatever ``matcher.load_state_dict(weights)`` raises; nothing is
        caught here.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    # Scales that get a GP and a DFN branch. Order is kept as "32" then "16"
    # so modules are created (and registered) in the same order as before.
    coarse_scales = ("32", "16")

    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {s: nn.Conv2d(512, feat_dim, 1, 1) for s in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict(
            {s: nn.Identity() for s in coarse_scales}
        ),
        rrb_d_dict=nn.ModuleDict(
            {s: RRB(gp_dim + feat_dim, dfn_dim) for s in coarse_scales}
        ),
        cab_dict=nn.ModuleDict(
            {s: CAB(2 * dfn_dim, dfn_dim) for s in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict(
            {s: RRB(dfn_dim, dfn_dim) for s in coarse_scales}
        ),
        terminal_module=nn.ModuleDict(
            {s: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for s in coarse_scales}
        ),
    )

    # Keyword arguments common to every ConvRefiner below.
    refiner_shared = dict(
        kernel_size=5,
        dw=True,
        hidden_blocks=8,
        displacement_emb="linear",
    )

    def corr_refiner(feat_channels, emb_dim, radius):
        # Refiner variant that also receives local_corr_radius / corr_in_other.
        # Input and hidden widths are both
        #   2 * feat_channels + emb_dim + (2 * radius + 1) ** 2
        # (the last term presumably being the number of positions in the
        # local correlation window -- confirm against ConvRefiner).
        width = 2 * feat_channels + emb_dim + (2 * radius + 1) ** 2
        return ConvRefiner(
            width,
            width,
            3,
            displacement_emb_dim=emb_dim,
            local_corr_radius=radius,
            corr_in_other=True,
            **refiner_shared
        )

    conv_refiner = nn.ModuleDict(
        {
            "16": corr_refiner(512, 128, 7),
            "8": corr_refiner(512, 64, 3),
            "4": corr_refiner(256, 32, 2),
            # The two finest scales do not pass local_corr_radius /
            # corr_in_other, so ConvRefiner's own defaults apply there.
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                3,
                displacement_emb_dim=16,
                **refiner_shared
            ),
            "1": ConvRefiner(
                2 * 3 + 6,
                24,
                3,
                displacement_emb_dim=6,
                **refiner_shared
            ),
        }
    )

    # One identically configured GP per coarse scale.
    gps = nn.ModuleDict(
        {
            s: GP(
                CosKernel,
                T=0.2,
                learn_temperature=False,
                only_attention=False,
                gp_dim=gp_dim,
                basis="fourier",
                no_cov=True,
            )
            for s in coarse_scales
        }
    )

    # 1x1 convolutions mapping 1024 / 2048 channels down to 512.
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )

    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)

    # pretrained=False: presumably the checkpoint loaded below supplies the
    # backbone parameters, so no separate pretrained download is requested.
    encoder = ResNet50(pretrained=False, high_res=False, freeze_bn=False)

    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        name="DKMv3",
        sample_mode=sample_mode,
        symmetric=symmetric,
        **kwargs
    ).to(device)

    # The return value of load_state_dict was previously bound to an unused
    # local; it is intentionally not kept.
    matcher.load_state_dict(weights)
    return matcher
|
|