Spaces:
Running
Running
File size: 3,772 Bytes
f90241e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import os
import sys
import urllib.request
from pathlib import Path
import numpy as np
import torch
import torchvision.transforms as tfm
from .. import logger
from ..utils.base_model import BaseModel
# Append the third_party/mast3r and third_party/dust3r directories (resolved
# relative to this file) to sys.path. This has to run before the `mast3r.*` /
# `dust3r.*` imports that follow — presumably those packages are only
# importable through these entries (TODO confirm against the repo layout).
# `mast3r_path` is also reused as the base directory of the default checkpoint.
mast3r_path = Path(__file__).parent / "../../third_party/mast3r"
sys.path.append(str(mast3r_path))
dust3r_path = Path(__file__).parent / "../../third_party/dust3r"
sys.path.append(str(dust3r_path))
from mast3r.model import AsymmetricMASt3R
from mast3r.fast_nn import fast_reciprocal_NNs
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.utils.image import load_images
from hloc.matchers.duster import Duster
# Module-wide compute device: CUDA when torch reports it as available,
# otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
class Mast3r(Duster):
    """Two-view matcher backed by the ``AsymmetricMASt3R`` network.

    The checkpoint is downloaded on first use, the two input images are run
    through the dust3r ``inference`` helper, and 2D-2D matches are extracted
    from the predicted descriptors with ``fast_reciprocal_NNs``.
    """

    default_conf = {
        "name": "Mast3r",
        # Default checkpoint location inside the third_party/mast3r directory.
        "model_path": mast3r_path
        / "model_weights/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
        # Upper bound on the number of matches returned by ``_forward``;
        # ``None`` disables the cap.
        "max_keypoints": 2000,
        "vit_patch_size": 16,
    }

    def _init(self, conf):
        """Fetch the weights if needed and load the network onto ``device``.

        Args:
            conf: Accepted for interface compatibility; the settings used
                here are read from ``self.conf``.
        """
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Wrap in Path so a plain string override of "model_path" works too;
        # download_weights() relies on Path attributes such as ``.parent``.
        self.model_path = Path(self.conf["model_path"])
        self.download_weights()
        self.net = AsymmetricMASt3R.from_pretrained(self.model_path).to(device)
        logger.info("Loaded Mast3r model")

    def download_weights(self):
        """Download the checkpoint to ``self.model_path`` unless it exists.

        The data is written to a sibling ``<name>.part`` file and moved into
        place only once the transfer has completed. An interrupted download
        therefore never leaves a truncated file at ``self.model_path`` that
        would pass the existence check on the next run.

        Raises:
            Whatever ``urllib.request.urlretrieve`` or ``os.replace`` raise;
            the partial file is removed before the error propagates.
        """
        url = "https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
        self.model_path.parent.mkdir(parents=True, exist_ok=True)
        if os.path.isfile(self.model_path):
            return

        logger.info("Downloading Mast3r(ViT large)... (takes a while)")
        partial_path = self.model_path.with_name(self.model_path.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial_path)
            # Same-directory move, so the final path appears in one step.
            os.replace(partial_path, self.model_path)
        finally:
            # After a successful os.replace() the partial file is gone and
            # this is a no-op; after a failure it removes the leftovers.
            if partial_path.exists():
                partial_path.unlink()
        logger.info("Downloading Mast3r(ViT large)... done!")

    @staticmethod
    def _normalize_image(img):
        """Return ``(img - 0.5) / 0.5`` using per-channel constants.

        The constants are created on ``img.device`` so the arithmetic works
        wherever the caller placed the image. They are shaped ``(1, 3, 1, 1)``,
        so ``img`` must broadcast against that shape.
        """
        mean = torch.tensor([0.5, 0.5, 0.5], device=img.device).view(1, 3, 1, 1)
        std = torch.tensor([0.5, 0.5, 0.5], device=img.device).view(1, 3, 1, 1)
        return (img - mean) / std

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: Mapping with the tensors ``"image0"`` and ``"image1"``.
                # assumes both are (1, 3, H, W) image batches — TODO confirm

        Returns:
            Dict with ``"keypoints0"`` and ``"keypoints1"``: matched
            coordinates in the first and second image, row ``i`` of one
            corresponding to row ``i`` of the other. With no matches both are
            empty ``(0, 2)`` tensors.
            # NOTE(review): the empty tensors use torch's default dtype while
            # the non-empty ones inherit the dtype of the arrays returned by
            # fast_reciprocal_NNs — confirm callers accept both.
        """
        img0 = self._normalize_image(data["image0"])
        img1 = self._normalize_image(data["image1"])

        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)

        # at this stage, you have the raw dust3r predictions
        pred1, pred2 = output["pred1"], output["pred2"]
        # Index 1 picks the second of the symmetrised pairs — presumably the
        # (image0, image1) ordering; verify against make_pairs.
        desc1 = pred1["desc"][1].squeeze(0).detach()
        desc2 = pred2["desc"][1].squeeze(0).detach()

        # find 2D-2D matches between the two images
        matches_im0, matches_im1 = fast_reciprocal_NNs(
            desc1,
            desc2,
            subsample_or_initxy1=2,
            device=device,
            dist="dot",
            block_size=2**13,
        )
        # Work on copies so the arrays handed to torch.from_numpy() below are
        # fresh and independent of whatever the matcher returned.
        mkpts0 = matches_im0.copy()
        mkpts1 = matches_im1.copy()

        # len() rather than truthiness: these are arrays, not plain lists.
        if len(mkpts0) == 0:
            logger.warning("Matched 0 points")
            return {
                "keypoints0": torch.zeros([0, 2]),
                "keypoints1": torch.zeros([0, 2]),
            }

        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(mkpts0) > top_k:
            # Keep top_k matches spread evenly over the match list.
            keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
            mkpts0 = mkpts0[keep]
            mkpts1 = mkpts1[keep]

        return {
            "keypoints0": torch.from_numpy(mkpts0),
            "keypoints1": torch.from_numpy(mkpts1),
        }
|