File size: 4,707 Bytes
4dfb78b
 
 
e64cfb1
 
4dfb78b
e64cfb1
 
 
4dfb78b
e64cfb1
4dfb78b
 
 
 
e64cfb1
4dfb78b
e64cfb1
 
4dfb78b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789fb0a
 
 
 
 
e64cfb1
4dfb78b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789fb0a
 
 
 
 
4dfb78b
 
 
 
 
 
 
 
 
 
 
 
 
789fb0a
4dfb78b
 
 
 
 
 
 
 
 
 
789fb0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dfb78b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import sys
import urllib.request
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as tfm

from .. import logger
from ..utils.base_model import BaseModel

# Location of the vendored dust3r checkout; it is put on sys.path so that the
# `from dust3r...` imports that follow can resolve it as a top-level package.
duster_path = Path(__file__).parent.joinpath("../../third_party/dust3r")
sys.path.append(os.fspath(duster_path))

from dust3r.cloud_opt import GlobalAlignerMode, global_aligner
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid

# Use CUDA when torch reports it available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Duster(BaseModel):
    """Two-view matcher built on the DUSt3R network.

    DUSt3R predicts a 3D point map for both images of a pair; confident
    pixels are then paired with dust3r's ``find_reciprocal_matches`` on
    those 3D points, yielding 2D-2D correspondences.
    """

    default_conf = {
        "name": "Duster3r",
        # Local checkpoint path; downloaded by `download_weights` if missing.
        "model_path": duster_path / "model_weights/duster_vit_large.pth",
        # Maximum number of returned matches (None disables subsampling).
        "max_keypoints": 3000,
        # `preprocess` resizes image sides to a multiple of this value.
        "vit_patch_size": 16,
    }

    def _init(self, conf):
        """Download the checkpoint if needed and load the network on `device`."""
        # Per-channel (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1] (assumes inputs
        # are float images in [0, 1]).
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # Accept both str and Path values from user configs.
        self.model_path = Path(self.conf["model_path"])
        # Read by `preprocess` when rounding image sides to whole patches.
        self.vit_patch_size = self.conf["vit_patch_size"]
        self.download_weights()
        self.net = AsymmetricCroCo3DStereo.from_pretrained(
            self.model_path
        ).to(device)
        logger.info("Loaded Dust3r model")

    def download_weights(self):
        """Fetch the DUSt3R ViT-Large checkpoint to `self.model_path` if absent.

        The file is first written to a sibling ``.part`` file and only moved
        into place once complete, so an interrupted download never leaves a
        truncated checkpoint that a later run would mistake for a valid one.
        """
        url = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"

        model_path = Path(self.model_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        if model_path.is_file():
            return

        logger.info("Downloading Duster(ViT large)... (takes a while)")
        tmp_path = model_path.with_name(model_path.name + ".part")
        try:
            urllib.request.urlretrieve(url, tmp_path)
            # Atomic on the same filesystem: readers see either no file or
            # the complete checkpoint.
            os.replace(tmp_path, model_path)
        finally:
            # Remove leftovers from a failed/interrupted download.
            if tmp_path.exists():
                tmp_path.unlink()

    def _round_to_patch(self, length):
        """Round `length` to the nearest multiple of the patch size (>= 1 patch)."""
        patch = self.vit_patch_size
        # round() of exactly .5 goes to even, so tiny sizes could become 0.
        return max(patch, patch * round(length / patch))

    def preprocess(self, img):
        """Resize `img` (unpacked as C, H, W) so both sides are multiples of
        the patch size, normalize it, and prepend a batch dimension.
        """
        # NOTE: assumes the caller/super-class gives img0 and img1 the same
        # resolution with h == w.
        _, h, _ = img.shape
        if h % self.vit_patch_size != 0:
            # An int size matches the smaller edge, keeping the aspect ratio.
            img = tfm.functional.resize(
                img, self._round_to_patch(h), antialias=True
            )

        _, new_h, new_w = img.shape
        if new_w % self.vit_patch_size != 0:
            safe_w = self._round_to_patch(new_w)
            img = tfm.functional.resize(img, (new_h, safe_w), antialias=True)

        return self.normalize(img).unsqueeze(0)

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Returns:
            dict with ``keypoints0`` / ``keypoints1``: matched pixel
            coordinates (from dust3r's ``xy_grid``) in the resolution of the
            aligned scene images; at most ``max_keypoints`` rows when that is
            not None. Both are empty ``(0, 2)`` tensors when no confident
            points are available.
        """
        # Same per-channel (x - 0.5) / 0.5 as before; Normalize broadcasts
        # over leading batch dims and follows the input tensor's device.
        img0 = self.normalize(data["image0"])
        img1 = self.normalize(data["image1"])

        images = [
            {"img": img0, "idx": 0, "instance": 0},
            {"img": img1, "idx": 1, "instance": 1},
        ]
        pairs = make_pairs(
            images, scene_graph="complete", prefilter=None, symmetrize=True
        )
        output = inference(pairs, self.net, device, batch_size=1)
        scene = global_aligner(
            output, device=device, mode=GlobalAlignerMode.PairViewer
        )

        # Keep only pixels flagged by the scene's confidence masks.
        imgs = scene.imgs
        confidence_masks = scene.get_masks()
        pts3d = scene.get_pts3d()
        pts2d_list, pts3d_list = [], []
        for i in range(2):
            conf_i = confidence_masks[i].cpu().numpy()
            # imgs[i].shape[:2] is (H, W); xy_grid is called with (W, H).
            pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i])
            pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])

        # Nothing can be matched if either side kept no confident points.
        if any(len(pts) == 0 for pts in pts3d_list):
            logger.warning(f"Matched {0} points")
            return {
                "keypoints0": torch.zeros([0, 2]),
                "keypoints1": torch.zeros([0, 2]),
            }

        reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
            *pts3d_list
        )
        logger.info(f"Found {num_matches} matches")
        mkpts1 = pts2d_list[1][reciprocal_in_P2]
        mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]

        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(mkpts0) > top_k:
            # Evenly spaced subsample; indices are distinct since top_k < N.
            keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
            mkpts0 = mkpts0[keep]
            mkpts1 = mkpts1[keep]

        return {
            "keypoints0": torch.from_numpy(mkpts0),
            "keypoints1": torch.from_numpy(mkpts1),
        }