|
import warnings

import torch
import torch.nn as nn

from roma.models.matcher import *
from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
from roma.models.encoders import *


def roma_model(
    resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs
):
    # RoMa weights and DINOv2 weights are loaded separately, as the DINOv2
    # backbone weights are not parameters of the matcher itself.
    torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow TF32 on cuDNN
    warnings.filterwarnings(
        "ignore", category=UserWarning, message="TypedStorage is deprecated"
    )
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    # Coarse global matcher: a transformer decoder that classifies each
    # location over a cls_to_coord_res x cls_to_coord_res grid of anchor
    # coordinates, plus one extra output channel used for certainty.
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(
            *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
        ),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,
    )
    dw = True  # use depthwise convolutions in the refiner blocks
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True
    # Per-scale refiners predicting a residual displacement plus certainty
    # (hence out_dim = 2 + 1). Input channels are stacked features from both
    # images, a displacement embedding, and, at strides 16/8/4, a local
    # correlation volume of radius r contributing (2r + 1)**2 channels.
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius=7,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "8": ConvRefiner(
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius=3,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "4": ConvRefiner(
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius=2,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=6,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
        }
    )
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    # Gaussian process module at stride 16, using a cosine kernel with a
    # Fourier positional basis.
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    # 1x1 projections mapping encoder features at each stride to the channel
    # counts expected by the decoder.
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict(
        {
            "16": proj16,
            "8": proj8,
            "4": proj4,
            "2": proj2,
            "1": proj1,
        }
    )
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(
        coordinate_decoder,
        gps,
        proj,
        conv_refiner,
        detach=True,
        scales=["16", "8", "4", "2", "1"],
        displacement_dropout_p=displacement_dropout_p,
        gm_warp_dropout_p=gm_warp_dropout_p,
    )

    # Backbone combining a VGG CNN (fine features) with a frozen DINOv2 ViT
    # (coarse features).
    encoder = CNNandDinov2(
        cnn_kwargs=dict(pretrained=False, amp=True),
        amp=True,
        use_vgg=True,
        dinov2_weights=dinov2_weights,
    )
    h, w = resolution
    symmetric = True
    attenuate_cert = True
    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        upsample_preds=upsample_preds,
        symmetric=symmetric,
        attenuate_cert=attenuate_cert,
        **kwargs,
    ).to(device)
    # Only load matcher weights when given; the DINOv2 weights were already
    # consumed by the encoder above.
    if weights is not None:
        matcher.load_state_dict(weights)
    return matcher
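

# Example usage -- a minimal sketch, not part of the original module. The
# checkpoint filenames and the 560x560 resolution below are illustrative
# placeholders; substitute the RoMa and DINOv2 state dicts you actually have
# on disk:
#
#     import torch
#
#     weights = torch.load("roma_outdoor.pth", map_location="cpu")
#     dinov2_weights = torch.load("dinov2_vitl14_pretrain.pth", map_location="cpu")
#     matcher = roma_model(
#         resolution=(560, 560),
#         upsample_preds=True,
#         device="cuda",
#         weights=weights,
#         dinov2_weights=dinov2_weights,
#     )
#     warp, certainty = matcher.match("imA.jpg", "imB.jpg")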
|
|