File size: 6,088 Bytes
d4c818f
7ef7e3c
d4c818f
 
 
 
7ef7e3c
d4c818f
 
 
 
 
 
 
 
 
 
7ef7e3c
 
 
 
 
 
d4c818f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import copy
import shutil
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional

sys.path.append(str(Path(__file__).parents[1]))

from hloc import (
    extract_features,
    logger,
    match_features,
    pairs_from_retrieval,
    reconstruction,
    visualization,
)

try:
    import pycolmap
except ImportError:
    logger.warning("pycolmap not installed, some features may not work")

from ui.viz import fig2im


class SfmEngine:
    """Sparse structure-from-motion pipeline built on hloc.

    ``call`` copies the input images to a temporary directory, extracts
    retrieval and local features, matches retrieval-selected image pairs,
    reconstructs a model with ``hloc.reconstruction`` (or reuses one already
    stored on disk) and exports the 3D points plus a 2D visualization.
    """

    def __init__(self, cfg: Optional[Dict[str, Any]] = None):
        """Store the configuration and prepare the output directory.

        Args:
            cfg: Engine configuration. The optional ``cfg["outputs"]`` is the
                directory for all results and is created if missing;
                ``cfg["matcher_zoo"]`` and ``cfg["retrieval_name"]`` are read
                by :meth:`call`. ``None`` is treated as an empty config.
                When no non-empty ``"outputs"`` entry is present, a new
                temporary directory is used instead.
        """
        self.cfg = cfg if cfg is not None else {}
        outputs = self.cfg.get("outputs")
        if outputs:
            outputs = Path(outputs)
            outputs.mkdir(parents=True, exist_ok=True)
        else:
            # Check the raw config value: a Path object is always truthy, so
            # testing Path(...) would never detect an empty/None setting.
            outputs = tempfile.mkdtemp()
        self.outputs = Path(outputs)

    def call(
        self,
        key: str,
        images: List[Path],
        camera_model: str,
        camera_params: List[float],
        max_keypoints: int,
        keypoint_threshold: float,
        match_threshold: float,
        ransac_threshold: int,
        ransac_confidence: float,
        ransac_max_iter: int,
        scene_graph: bool,
        global_feature: str,
        top_k: int = 10,
        mapper_refine_focal_length: bool = False,
        mapper_refine_principle_points: bool = False,
        mapper_refine_extra_params: bool = False,
    ):
        """Run feature extraction, matching and reconstruction on ``images``.

        Args:
            key (str): Entry of ``cfg["matcher_zoo"]`` that provides the
                ``"matcher"`` and ``"feature"`` configurations.
            images (List[Path]): Image files to reconstruct. They are copied
                into a temporary directory that is removed before returning.
            camera_model (str): Camera model. Currently unused by this method.
            camera_params (List[float]): Camera parameters. Currently unused.
            max_keypoints (int): Maximum number of keypoints; written into both
                the feature and the matcher configuration.
            keypoint_threshold (float): Keypoint threshold for the feature conf.
            match_threshold (float): Match threshold for the matcher conf.
            ransac_threshold (int): RANSAC threshold. Currently unused.
            ransac_confidence (float): RANSAC confidence. Currently unused.
            ransac_max_iter (int): Max RANSAC iterations. Currently unused.
            scene_graph (bool): Scene-graph flag. Currently unused.
            global_feature (str): Global feature name. Currently unused; the
                retrieval feature is taken from ``cfg["retrieval_name"]``
                (default ``"netvlad"``).
            top_k (int): Passed as ``num_matched`` to
                ``pairs_from_retrieval.main``.
            mapper_refine_focal_length (bool): Mapper option
                ``ba_refine_focal_length``.
            mapper_refine_principle_points (bool): Mapper option
                ``ba_refine_principal_point``.
            mapper_refine_extra_params (bool): Mapper option
                ``ba_refine_extra_params``.

        Returns:
            tuple: Path to ``points3D.obj`` in the SfM directory, and the 2D
            SfM visualization as returned by ``fig2im`` divided by 255.0.

        Raises:
            ValueError: If ``images`` is empty/None, or the matcher entry has
                no ``"feature"`` configuration.
            KeyError: If ``key`` is not in ``cfg["matcher_zoo"]``.
            RuntimeError: If hloc does not produce a reconstruction.
        """
        if not images:
            logger.error(f"No input images were provided: {images}.")
            raise ValueError("SfM requires at least one input image.")

        feature_conf, match_conf = self._prepare_confs(
            key, max_keypoints, keypoint_threshold, match_threshold
        )

        # retrieval
        retrieval_name = self.cfg.get("retrieval_name", "netvlad")
        retrieval_conf = extract_features.confs[retrieval_name]

        mapper_options = {
            "ba_refine_extra_params": mapper_refine_extra_params,
            "ba_refine_focal_length": mapper_refine_focal_length,
            "ba_refine_principal_point": mapper_refine_principle_points,
            "ba_local_max_num_iterations": 40,
            "ba_local_max_refinements": 3,
            "ba_global_max_num_iterations": 100,
            # The three options below suit individual/video captures; for
            # internet photo collections they should be left at their defaults.
            "min_focal_length_ratio": 0.1,
            "max_focal_length_ratio": 10,
            "max_extra_param": 1e15,
        }

        sfm_dir = self.outputs / "sfm_{}".format(key)
        sfm_pairs = self.outputs / "pairs-sfm.txt"
        sfm_dir.mkdir(exist_ok=True, parents=True)

        temp_images = Path(tempfile.mkdtemp())
        try:
            logger.info(f"Copying images to {temp_images}.")
            for image in images:
                shutil.copy(image, temp_images)

            # Global descriptors -> retrieval-based list of image pairs.
            retrieval_path = extract_features.main(
                retrieval_conf, temp_images, self.outputs
            )
            pairs_from_retrieval.main(
                retrieval_path, sfm_pairs, num_matched=top_k
            )

            # Local features and pairwise matches.
            feature_path = extract_features.main(
                feature_conf, temp_images, self.outputs
            )
            match_path = match_features.main(
                match_conf, sfm_pairs, feature_conf["output"], self.outputs
            )

            # NOTE(review): sfm_dir is keyed only by matcher `key`, so a model
            # from an earlier call is reused even if `images` differ —
            # confirm this caching is intended.
            model = self._load_cached_model(sfm_dir)
            if model is None:
                model = reconstruction.main(
                    sfm_dir,
                    temp_images,
                    sfm_pairs,
                    feature_path,
                    match_path,
                    mapper_options=mapper_options,
                )
            if model is None:
                raise RuntimeError(f"SfM reconstruction failed in {sfm_dir}.")

            model_3d = sfm_dir / "points3D.obj"
            self._write_points_obj(model.points3D, model_3d)

            viz_2d = visualization.visualize_sfm_2d(
                model, temp_images, color_by="visibility", n=2, dpi=300
            )
            return model_3d, fig2im(viz_2d) / 255.0
        finally:
            # The copied images are only needed up to the 2D visualization.
            shutil.rmtree(temp_images, ignore_errors=True)

    def _prepare_confs(
        self,
        key: str,
        max_keypoints: int,
        keypoint_threshold: float,
        match_threshold: float,
    ):
        """Return private copies of (feature_conf, match_conf) for ``key``.

        The confs are deep-copied so the per-call thresholds do not leak back
        into the shared ``cfg["matcher_zoo"]`` entry.
        """
        entry = self.cfg["matcher_zoo"][key]
        if not entry.get("feature"):
            raise ValueError(
                f"Matcher '{key}' has no 'feature' configuration; "
                "SfM requires a sparse feature extractor."
            )
        match_conf = copy.deepcopy(entry["matcher"])
        match_conf["model"]["max_keypoints"] = max_keypoints
        match_conf["model"]["match_threshold"] = match_threshold

        feature_conf = copy.deepcopy(entry["feature"])
        feature_conf["model"]["max_keypoints"] = max_keypoints
        feature_conf["model"]["keypoint_threshold"] = keypoint_threshold
        return feature_conf, match_conf

    @staticmethod
    def _load_cached_model(sfm_dir: Path):
        """Return the reconstruction stored in ``sfm_dir``, or ``None``.

        ``None`` is returned when ``pycolmap.Reconstruction`` raises
        ValueError while reading the directory.
        """
        try:
            return pycolmap.Reconstruction(str(sfm_dir))
        except ValueError:
            logger.info(f"sfm_dir not exists model: {sfm_dir}")
            return None

    @staticmethod
    def _write_points_obj(points3D, path: Path) -> None:
        """Write 3D points to ``path`` as OBJ text.

        Each point becomes a ``v`` line (its ``xyz``) followed by a ``vn``
        line that stores its ``color`` divided by 255.0 — the vertex-normal
        slot is reused to carry color.
        """
        lines = []
        for point3D in points3D.values():
            p, c = point3D.xyz, point3D.color
            lines.append("v {} {} {}\n".format(p[0], p[1], p[2]))
            lines.append(
                "vn {} {} {}\n".format(
                    c[0] / 255.0, c[1] / 255.0, c[2] / 255.0
                )
            )
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(lines)