import torch
from torch import nn

from ..dkm import *
from ..encoders import *


def DKMv3(weights, h, w, symmetric=True, sample_mode="threshold_balanced", device=None, **kwargs):
    """Build the DKMv3 regression matcher and load the given state dict."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    # Coarse coordinate decoder (DFN) operating on the stride-32 and stride-16 feature maps.
    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {
                "32": nn.Conv2d(512, feat_dim, 1, 1),
                "16": nn.Conv2d(512, feat_dim, 1, 1),
            }
        ),
        pred_input_modules=nn.ModuleDict(
            {
                "32": nn.Identity(),
                "16": nn.Identity(),
            }
        ),
        rrb_d_dict=nn.ModuleDict(
            {
                "32": RRB(gp_dim + feat_dim, dfn_dim),
                "16": RRB(gp_dim + feat_dim, dfn_dim),
            }
        ),
        cab_dict=nn.ModuleDict(
            {
                "32": CAB(2 * dfn_dim, dfn_dim),
                "16": CAB(2 * dfn_dim, dfn_dim),
            }
        ),
        rrb_u_dict=nn.ModuleDict(
            {
                "32": RRB(dfn_dim, dfn_dim),
                "16": RRB(dfn_dim, dfn_dim),
            }
        ),
        terminal_module=nn.ModuleDict(
            {
                "32": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
                "16": nn.Conv2d(dfn_dim, 3, 1, 1, 0),
            }
        ),
    )
    # Convolutional refiners for the finer strides (16, 8, 4, 2, 1).
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius=7,
                corr_in_other=True,
            ),
            "8": ConvRefiner(
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius=3,
                corr_in_other=True,
            ),
            "4": ConvRefiner(
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius=2,
                corr_in_other=True,
            ),
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
            ),
            "1": ConvRefiner(
                2 * 3 + 6,
                24,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=6,
            ),
        }
    )
    # Gaussian process modules with a cosine kernel for the two coarsest strides.
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp32 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"32": gp32, "16": gp16})
    # Project the deep ResNet stage outputs (1024/2048 channels) down to 512 channels.
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )
    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)
    encoder = ResNet50(pretrained=False, high_res=False, freeze_bn=False)
    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        name="DKMv3",
        sample_mode=sample_mode,
        symmetric=symmetric,
        **kwargs,
    ).to(device)
    matcher.load_state_dict(weights)
    return matcher
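

# Example usage (a minimal sketch, kept as a comment so this module stays importable):
# the checkpoint path and image resolution below are assumptions for illustration only,
# not values shipped with this module.
#
#     weights = torch.load("DKMv3_outdoor.pth", map_location="cpu")  # hypothetical checkpoint path
#     matcher = DKMv3(weights, h=540, w=720)
#     # The returned RegressionMatcher can then be used for dense matching on image pairs
#     # via the matching/sampling methods it exposes.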