Vincentqyw
fix: roma
4c12b36
raw
history blame
4.74 kB
import torch
from torch import nn
from ..dkm import *
from ..encoders import *
def DKMv3(
    weights,
    h,
    w,
    symmetric=True,
    sample_mode="threshold_balanced",
    device=None,
    **kwargs
):
    """Assemble the "DKMv3" RegressionMatcher and load ``weights`` into it.

    The matcher combines a ``ResNet50`` encoder (built with
    ``pretrained=False``) with a ``Decoder`` made of a ``DFN`` coordinate
    decoder, one ``GP`` module per coarse key ("32", "16"), 1x1 projection
    convolutions, and ``ConvRefiner`` modules under the keys "16", "8", "4",
    "2" and "1".

    Args:
        weights: State dict handed to ``matcher.load_state_dict``.
        h: Forwarded to ``RegressionMatcher`` as ``h`` (presumably the
            working image height -- confirm against RegressionMatcher).
        w: Forwarded to ``RegressionMatcher`` as ``w`` (presumably the
            working image width -- confirm against RegressionMatcher).
        symmetric: Forwarded to ``RegressionMatcher`` unchanged.
        sample_mode: Forwarded to ``RegressionMatcher`` unchanged.
        device: Device the matcher is moved to. When ``None``, CUDA is used
            if ``torch.cuda.is_available()``, otherwise the CPU.
        **kwargs: Extra keyword arguments forwarded to ``RegressionMatcher``.

    Returns:
        The constructed matcher, moved to ``device`` with ``weights`` loaded.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    # Keys shared by the coordinate decoder and the GP modules. The order is
    # kept as "32" then "16" so modules are instantiated in the same sequence
    # as before (identical RNG consumption during initialisation).
    coarse_scales = ("32", "16")

    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {scale: nn.Conv2d(512, feat_dim, 1, 1) for scale in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict(
            {scale: nn.Identity() for scale in coarse_scales}
        ),
        rrb_d_dict=nn.ModuleDict(
            {scale: RRB(gp_dim + feat_dim, dfn_dim) for scale in coarse_scales}
        ),
        cab_dict=nn.ModuleDict(
            {scale: CAB(2 * dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict(
            {scale: RRB(dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        terminal_module=nn.ModuleDict(
            {scale: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for scale in coarse_scales}
        ),
    )

    # Keyword arguments common to every ConvRefiner below.
    refiner_kwargs = dict(
        kernel_size=5,
        dw=True,
        hidden_blocks=8,
        displacement_emb="linear",
    )

    def local_corr_refiner(feat_channels, emb_dim, radius):
        # Input and hidden widths are equal here: twice the feature channel
        # count, plus the displacement embedding size, plus one channel per
        # cell of the (2 * radius + 1) ** 2 local-correlation window.
        width = 2 * feat_channels + emb_dim + (2 * radius + 1) ** 2
        return ConvRefiner(
            width,
            width,
            3,
            displacement_emb_dim=emb_dim,
            local_corr_radius=radius,
            corr_in_other=True,
            **refiner_kwargs
        )

    conv_refiner = nn.ModuleDict(
        {
            "16": local_corr_refiner(512, 128, 7),
            "8": local_corr_refiner(512, 64, 3),
            "4": local_corr_refiner(256, 32, 2),
            # The two finest refiners are built without the local-correlation
            # arguments, so ConvRefiner's own defaults apply for those.
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                3,
                displacement_emb_dim=16,
                **refiner_kwargs
            ),
            "1": ConvRefiner(
                2 * 3 + 6,
                24,
                3,
                displacement_emb_dim=6,
                **refiner_kwargs
            ),
        }
    )

    # Both GP modules share one configuration; only their dict key differs.
    gp_kwargs = dict(
        T=0.2,
        learn_temperature=False,
        only_attention=False,
        gp_dim=gp_dim,
        basis="fourier",
        no_cov=True,
    )
    gps = nn.ModuleDict(
        {scale: GP(CosKernel, **gp_kwargs) for scale in coarse_scales}
    )

    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )

    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
    # pretrained=False: the encoder's parameters are expected to come from
    # ``weights`` via the load_state_dict call below.
    encoder = ResNet50(pretrained=False, high_res=False, freeze_bn=False)

    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        name="DKMv3",
        sample_mode=sample_mode,
        symmetric=symmetric,
        **kwargs
    ).to(device)

    # The return value is intentionally not kept. NOTE(review): assuming
    # RegressionMatcher uses the stock nn.Module.load_state_dict, default
    # strict loading raises on missing/unexpected keys, so the returned key
    # report carries no extra information on success.
    matcher.load_state_dict(weights)
    return matcher