Realcat
add: GIM (https://github.com/xuelunshen/gim)
c0283b3
raw
history blame
4.28 kB
import os
import sys
import torch
from pathlib import Path
import torchvision.transforms as tfm
import torch.nn.functional as F
import urllib.request
import numpy as np
from ..utils.base_model import BaseModel
from .. import logger
duster_path = Path(__file__).parent / "../../third_party/dust3r"
sys.path.append(str(duster_path))
from dust3r.inference import inference
from dust3r.model import load_model
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Duster(BaseModel):
default_conf = {
"name": "Duster3r",
"model_path": duster_path / "model_weights/duster_vit_large.pth",
"max_keypoints": 3000,
"vit_patch_size": 16,
}
def _init(self, conf):
self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
self.model_path = self.conf["model_path"]
self.download_weights()
self.net = load_model(self.model_path, device)
logger.info(f"Loaded Dust3r model")
def download_weights(self):
url = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
self.model_path.parent.mkdir(parents=True, exist_ok=True)
if not os.path.isfile(self.model_path):
logger.info("Downloading Duster(ViT large)... (takes a while)")
urllib.request.urlretrieve(url, self.model_path)
def preprocess(self, img):
# the super-class already makes sure that img0,img1 have
# same resolution and that h == w
_, h, _ = img.shape
imsize = h
if not ((h % self.vit_patch_size) == 0):
imsize = int(
self.vit_patch_size * round(h / self.vit_patch_size, 0)
)
img = tfm.functional.resize(img, imsize, antialias=True)
_, new_h, new_w = img.shape
if not ((new_w % self.vit_patch_size) == 0):
safe_w = int(
self.vit_patch_size * round(new_w / self.vit_patch_size, 0)
)
img = tfm.functional.resize(img, (new_h, safe_w), antialias=True)
img = self.normalize(img).unsqueeze(0)
return img
def _forward(self, data):
img0, img1 = data["image0"], data["image1"]
# img0 = self.preprocess(img0)
# img1 = self.preprocess(img1)
images = [
{"img": img0, "idx": 0, "instance": 0},
{"img": img1, "idx": 1, "instance": 1},
]
pairs = make_pairs(
images, scene_graph="complete", prefilter=None, symmetrize=True
)
output = inference(pairs, self.net, device, batch_size=1)
scene = global_aligner(
output, device=device, mode=GlobalAlignerMode.PairViewer
)
batch_size = 1
schedule = "cosine"
lr = 0.01
niter = 300
loss = scene.compute_global_alignment(
init="mst", niter=niter, schedule=schedule, lr=lr
)
# retrieve useful values from scene:
confidence_masks = scene.get_masks()
pts3d = scene.get_pts3d()
imgs = scene.imgs
pts2d_list, pts3d_list = [], []
for i in range(2):
conf_i = confidence_masks[i].cpu().numpy()
pts2d_list.append(
xy_grid(*imgs[i].shape[:2][::-1])[conf_i]
) # imgs[i].shape[:2] = (H, W)
pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
*pts3d_list
)
print(f"found {num_matches} matches")
mkpts1 = pts2d_list[1][reciprocal_in_P2]
mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
top_k = self.conf["max_keypoints"]
if top_k is not None and len(mkpts0) > top_k:
keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
mkpts0 = mkpts0[keep]
mkpts1 = mkpts1[keep]
breakpoint()
pred = {
"keypoints0": torch.from_numpy(mkpts0),
"keypoints1": torch.from_numpy(mkpts1),
}
return pred