Spaces:
Running
Running
import warnings | |
import torch.nn as nn | |
from roma.models.matcher import * | |
from roma.models.transformer import Block, TransformerDecoder, MemEffAttention | |
from roma.models.encoders import * | |
def roma_model(
    resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs
):
    """Build a RoMa ``RegressionMatcher`` and (optionally) load its weights.

    RoMa weights and DINOv2 weights are loaded separately: ``dinov2_weights``
    is handed to the ``CNNandDinov2`` encoder, while ``weights`` is loaded into
    the assembled matcher via ``load_state_dict``.

    Side effects (process-wide, not scoped to the returned model):
      * enables TF32 for CUDA matmul and cuDNN;
      * installs a ``warnings`` filter that ignores the
        "TypedStorage is deprecated" ``UserWarning``.

    Args:
        resolution: Two-element sequence unpacked as ``(h, w)`` and forwarded
            to ``RegressionMatcher`` as ``h=`` / ``w=`` (presumably height and
            width in pixels -- confirm against ``RegressionMatcher``).
        upsample_preds: Forwarded unchanged to ``RegressionMatcher``.
        device: Passed to ``.to(device)`` on the built matcher.
        weights: State dict loaded into the matcher with ``load_state_dict``.
            If ``None``, loading is skipped and a ``UserWarning`` is emitted;
            the matcher then keeps its freshly constructed parameters.
        dinov2_weights: Forwarded to ``CNNandDinov2(dinov2_weights=...)``.
        **kwargs: Extra keyword arguments forwarded to ``RegressionMatcher``.

    Returns:
        The constructed ``RegressionMatcher`` moved to ``device``.
    """
    # Explicit import rather than relying on ``torch`` leaking in through the
    # star imports at the top of the module.
    import torch

    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    warnings.filterwarnings(
        "ignore", category=UserWarning, message="TypedStorage is deprecated"
    )

    # --- Global-matching coordinate decoder -------------------------------
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(
            *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
        ),
        decoder_dim,
        # One class per cell of a cls_to_coord_res x cls_to_coord_res grid,
        # plus one extra output.
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,
    )

    # --- Per-scale convolutional refiners ---------------------------------
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True
    # Settings shared verbatim by every refiner below.
    refiner_kwargs = dict(
        kernel_size=kernel_size,
        dw=dw,
        hidden_blocks=hidden_blocks,
        displacement_emb=displacement_emb,
        amp=True,
        disable_local_corr_grad=disable_local_corr_grad,
        bn_momentum=0.01,
    )
    # The input width of each refiner appears to be
    #   2 * <channels of proj[scale]> + displacement_emb_dim + (2r + 1) ** 2
    # where the last term (local-correlation window) is present only for the
    # scales that set local_corr_radius=r. The "2 + 1" output presumably is a
    # 2-D displacement plus one certainty channel -- confirm in ConvRefiner.
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 + 1,
                displacement_emb_dim=128,
                local_corr_radius=7,
                corr_in_other=True,
                **refiner_kwargs,
            ),
            "8": ConvRefiner(
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 + 1,
                displacement_emb_dim=64,
                local_corr_radius=3,
                corr_in_other=True,
                **refiner_kwargs,
            ),
            "4": ConvRefiner(
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 + 1,
                displacement_emb_dim=32,
                local_corr_radius=2,
                corr_in_other=True,
                **refiner_kwargs,
            ),
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                2 + 1,
                displacement_emb_dim=16,
                **refiner_kwargs,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                displacement_emb_dim=6,
                **refiner_kwargs,
            ),
        }
    )

    # --- Gaussian-process module used at the coarsest scale ---------------
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})

    # --- 1x1 conv + BatchNorm projections of encoder features per scale ---
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict(
        {
            "16": proj16,
            "8": proj8,
            "4": proj4,
            "2": proj2,
            "1": proj1,
        }
    )

    # --- Assemble decoder, encoder and matcher ----------------------------
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(
        coordinate_decoder,
        gps,
        proj,
        conv_refiner,
        detach=True,
        scales=["16", "8", "4", "2", "1"],
        displacement_dropout_p=displacement_dropout_p,
        gm_warp_dropout_p=gm_warp_dropout_p,
    )
    encoder = CNNandDinov2(
        cnn_kwargs=dict(pretrained=False, amp=True),
        amp=True,
        use_vgg=True,
        dinov2_weights=dinov2_weights,
    )
    h, w = resolution
    symmetric = True
    attenuate_cert = True
    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        upsample_preds=upsample_preds,
        symmetric=symmetric,
        attenuate_cert=attenuate_cert,
        **kwargs,
    ).to(device)

    # ``weights`` defaults to None; passing None to load_state_dict would
    # raise, so honour the default by skipping the load and saying so.
    if weights is None:
        warnings.warn(
            "roma_model called with weights=None; skipping load_state_dict, "
            "so the matcher keeps its initial (untrained) parameters.",
            UserWarning,
            stacklevel=2,
        )
    else:
        matcher.load_state_dict(weights)
    return matcher