import os
import sys
import urllib.request
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as tfm

from .. import logger
from ..utils.base_model import BaseModel
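
# Both MASt3R and DUSt3R are vendored under third_party/; add them to
# sys.path so their packages resolve in the imports below.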
mast3r_path = Path(__file__).parent / "../../third_party/mast3r"
sys.path.append(str(mast3r_path))
dust3r_path = Path(__file__).parent / "../../third_party/dust3r"
sys.path.append(str(dust3r_path))

from mast3r.model import AsymmetricMASt3R
from mast3r.fast_nn import fast_reciprocal_NNs

from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.utils.image import load_images

from hloc.matchers.duster import Duster
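
# Run on the GPU when one is available; otherwise fall back to CPU.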
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Mast3r(Duster):
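    """Two-view matcher built on MASt3R, reusing the DUSt3R (Duster) wrapper."""
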
    default_conf = {
        "name": "Mast3r",
        "model_path": mast3r_path
        / "model_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
        "max_keypoints": 2000,
        "vit_patch_size": 16,
    }

    def _init(self, conf):
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self.model_path = self.conf["model_path"]
        self.download_weights()
        self.net = AsymmetricMASt3R.from_pretrained(self.model_path).to(device)
        logger.info("Loaded Mast3r model")

    def download_weights(self):
        url = "https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"

        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        # skip the download if the checkpoint is already on disk
        if not os.path.isfile(self.model_path):
            logger.info("Downloading Mast3r (ViT large)... (takes a while)")
            urllib.request.urlretrieve(url, self.model_path)
            logger.info("Downloading Mast3r (ViT large)... done!")

    def _forward(self, data):
        img0, img1 = data["image0"], data["image1"]
        # normalize both images to [-1, 1], as MASt3R expects
        mean = torch.tensor([0.5, 0.5, 0.5]).to(device)
        std = torch.tensor([0.5, 0.5, 0.5]).to(device)

        img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
        img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)

        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
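        # build a symmetrized pair list and run joint two-view inference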
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)

        # at this stage, we have the raw DUSt3R-style predictions
        view1, pred1 = output["view1"], output["pred1"]
        view2, pred2 = output["view2"], output["pred2"]

        # take the dense per-pixel descriptors predicted for one of the two
        # directed pairs produced by symmetrize=True
        desc1, desc2 = (
            pred1["desc"][1].squeeze(0).detach(),
            pred2["desc"][1].squeeze(0).detach(),
        )

        # find 2D-2D matches between the two images via reciprocal (mutual)
        # nearest-neighbour search in descriptor space, computed in blocks
        # to bound memory use
        matches_im0, matches_im1 = fast_reciprocal_NNs(
            desc1,
            desc2,
            subsample_or_initxy1=2,
            device=device,
            dist="dot",
            block_size=2**13,
        )

        mkpts0 = matches_im0.copy()
        mkpts1 = matches_im1.copy()

        if len(mkpts0) == 0:
            pred = {
                "keypoints0": torch.zeros([0, 2]),
                "keypoints1": torch.zeros([0, 2]),
            }
            logger.warning("Matched 0 points")
        else:
            # cap the match count by keeping a uniformly spaced subset
            top_k = self.conf["max_keypoints"]
            if top_k is not None and len(mkpts0) > top_k:
                keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
                mkpts0 = mkpts0[keep]
                mkpts1 = mkpts1[keep]
            pred = {
                "keypoints0": torch.from_numpy(mkpts0),
                "keypoints1": torch.from_numpy(mkpts1),
            }
        return pred
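

# Minimal usage sketch (illustrative only; in hloc, matchers are normally
# constructed and invoked through the framework, and this assumes BaseModel's
# usual __init__/__call__ wiring):
#
#   matcher = Mast3r(Mast3r.default_conf)
#   pred = matcher({
#       "image0": torch.rand(1, 3, 512, 512).to(device),
#       "image1": torch.rand(1, 3, 512, 512).to(device),
#   })
#   mkpts0, mkpts1 = pred["keypoints0"], pred["keypoints1"]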