import warnings
import torch.nn as nn
import torch
from roma.models.matcher import *
from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
from roma.models.encoders import *


def roma_model(resolution, upsample_preds, device=None, weights=None, dinov2_weights=None,
               amp_dtype: torch.dtype = torch.float16, **kwargs):
    """Build a RoMa ``RegressionMatcher`` and optionally load its weights.

    The model is assembled from:
      * a ``CNNandDinov2`` encoder (VGG CNN branch + DINOv2 branch),
      * a ``Decoder`` combining a transformer coordinate decoder, a GP at
        scale "16", 1x1-conv + BatchNorm projections, and ``ConvRefiner``s at
        scales "16", "8", "4", "2", "1".

    Args:
        resolution: Pair unpacked as ``(h, w)`` and passed to ``RegressionMatcher``.
        upsample_preds: Forwarded to ``RegressionMatcher``.
        device: Target passed to ``.to(device)``. With ``None`` this presumably
            leaves the model where it was created (torch ``Module.to(None)``).
        weights: State dict for the matcher, loaded with ``load_state_dict``.
            If ``None``, loading is skipped and the matcher keeps its freshly
            initialized parameters (previously this crashed).
        dinov2_weights: Forwarded to ``CNNandDinov2``. RoMa weights and DINOv2
            weights are loaded separately, as the DINOv2 weights are not
            parameters of this model.
        amp_dtype: Autocast dtype, forwarded to the encoder only. Other
            submodules receive ``amp=True`` without an explicit dtype.
        **kwargs: Extra keyword arguments forwarded to ``RegressionMatcher``.

    Returns:
        The constructed ``RegressionMatcher``.
    """
    # NOTE: enabling TF32 (torch.backends.cuda.matmul.allow_tf32 /
    # torch.backends.cudnn.allow_tf32) was considered but left off, as it may
    # degrade accuracy.
    # Process-wide side effect: silences torch's TypedStorage deprecation warning.
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    # Classifier over a cls_to_coord_res x cls_to_coord_res grid of coordinate bins (+1 extra class).
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,
    )

    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True

    # Input channels per scale: two projected feature maps + displacement
    # embedding + (2r+1)^2 local-correlation channels, where r is local_corr_radius.
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 * 512 + 128 + (2 * 7 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius=7,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "8": ConvRefiner(
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 * 512 + 64 + (2 * 3 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius=3,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "4": ConvRefiner(
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 * 256 + 32 + (2 * 2 + 1) ** 2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius=2,
                corr_in_other=True,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            # Finest two scales: no local correlation inputs.
            "2": ConvRefiner(
                2 * 64 + 16,
                128 + 16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=6,
                amp=True,
                disable_local_corr_grad=disable_local_corr_grad,
                bn_momentum=0.01,
            ),
        }
    )

    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})

    # 1x1 conv + BatchNorm projections of encoder features, one per scale.
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict({
        "16": proj16,
        "8": proj8,
        "4": proj4,
        "2": proj2,
        "1": proj1,
    })

    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(
        coordinate_decoder,
        gps,
        proj,
        conv_refiner,
        detach=True,
        scales=["16", "8", "4", "2", "1"],
        displacement_dropout_p=displacement_dropout_p,
        gm_warp_dropout_p=gm_warp_dropout_p,
    )

    encoder = CNNandDinov2(
        cnn_kwargs=dict(
            pretrained=False,
            amp=True),
        amp=True,
        use_vgg=True,
        dinov2_weights=dinov2_weights,
        amp_dtype=amp_dtype,
    )

    h, w = resolution
    symmetric = True
    attenuate_cert = True
    sample_mode = "threshold_balanced"
    matcher = RegressionMatcher(
        encoder,
        decoder,
        h=h,
        w=w,
        upsample_preds=upsample_preds,
        symmetric=symmetric,
        attenuate_cert=attenuate_cert,
        sample_mode=sample_mode,
        **kwargs,
    ).to(device)

    # `weights` defaults to None; only load when a state dict was actually supplied,
    # instead of crashing inside load_state_dict(None).
    if weights is not None:
        matcher.load_state_dict(weights)
    return matcher