Realcat committed on
Commit
7ef7e3c
·
1 Parent(s): 1f2ecd8

update: gradio to 5.1.0

Browse files
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🤗
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.28.3
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
 
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
api/__init__.py ADDED
File without changes
api/client.py CHANGED
@@ -1,18 +1,102 @@
1
  import argparse
 
 
2
  import pickle
3
  import time
4
- from typing import Dict
5
 
 
6
  import numpy as np
7
  import requests
8
- from loguru import logger
9
 
10
- API_URL_MATCH = "http://127.0.0.1:8001/v1/match"
11
- API_URL_EXTRACT = "http://127.0.0.1:8001/v1/extract"
12
- API_URL_EXTRACT_V2 = "http://127.0.0.1:8001/v2/extract"
13
 
 
14
 
15
- def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  """
17
  Send a request to the API to generate a match between two images.
18
 
@@ -28,6 +112,7 @@ def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]:
28
  """
29
  files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")}
30
  try:
 
31
  response = requests.post(API_URL_MATCH, files=files)
32
  pred = {}
33
  if response.status_code == 200:
@@ -44,68 +129,56 @@ def send_generate_request(path0: str, path1: str) -> Dict[str, np.ndarray]:
44
  return pred
45
 
46
 
47
- def send_generate_request1(path0: str) -> Dict[str, np.ndarray]:
 
 
48
  """
49
  Send a request to the API to extract features from an image.
50
 
51
  Args:
52
- path0 (str): The path to the image.
53
 
54
  Returns:
55
- Dict[str, np.ndarray]: A dictionary containing the extracted features.
56
- The keys are "keypoints", "descriptors", and "scores", and the
57
- values are ndarrays of shape (N, 2), (N, 128), and (N,),
58
- respectively.
59
  """
60
- files = {"image": open(path0, "rb")}
61
- try:
62
- response = requests.post(API_URL_EXTRACT, files=files)
63
- pred: Dict[str, np.ndarray] = {}
64
- if response.status_code == 200:
65
- pred = response.json()
66
- for key in list(pred.keys()):
67
- pred[key] = np.array(pred[key])
68
- else:
69
- print(
70
- f"Error: Response code {response.status_code} - {response.text}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  )
72
- finally:
73
- files["image"].close()
74
- return pred
75
-
76
 
77
- def send_generate_request2(image_path: str) -> Dict[str, np.ndarray]:
78
- """
79
- Send a request to the API to extract features from an image.
80
-
81
- Args:
82
- image_path (str): The path to the image.
83
 
84
- Returns:
85
- Dict[str, np.ndarray]: A dictionary containing the extracted features.
86
- The keys are "keypoints", "descriptors", and "scores", and the
87
- values are ndarrays of shape (N, 2), (N, 128), and (N,), respectively.
88
- """
89
- data = {
90
- "image_path": image_path,
91
- "max_keypoints": 1024,
92
- "reference_points": [[0.0, 0.0], [1.0, 1.0]],
93
- }
94
- pred = {}
95
  try:
96
- response = requests.post(API_URL_EXTRACT_V2, json=data)
97
- pred: Dict[str, np.ndarray] = {}
98
- if response.status_code == 200:
99
- pred = response.json()
100
- for key in list(pred.keys()):
101
- pred[key] = np.array(pred[key])
102
- else:
103
- print(
104
- f"Error: Response code {response.status_code} - {response.text}"
105
- )
106
  except Exception as e:
107
  print(f"An error occurred: {e}")
108
- return pred
109
 
110
 
111
  if __name__ == "__main__":
@@ -116,32 +189,37 @@ if __name__ == "__main__":
116
  "--image0",
117
  required=False,
118
  help="Path for the file's melody",
119
- default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg",
120
  )
121
  parser.add_argument(
122
  "--image1",
123
  required=False,
124
  help="Path for the file's melody",
125
- default="../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
126
  )
127
  args = parser.parse_args()
128
- for i in range(10):
129
- t1 = time.time()
130
- preds = send_generate_request(args.image0, args.image1)
131
- t2 = time.time()
132
- logger.info(f"Time cost1: {(t2 - t1)} seconds")
133
-
134
- for i in range(10):
135
- t1 = time.time()
136
- preds = send_generate_request1(args.image0)
137
- t2 = time.time()
138
- logger.info(f"Time cost2: {(t2 - t1)} seconds")
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  for i in range(10):
141
  t1 = time.time()
142
- preds = send_generate_request2(args.image0)
143
  t2 = time.time()
144
- logger.info(f"Time cost2: {(t2 - t1)} seconds")
145
 
 
146
  with open("preds.pkl", "wb") as f:
147
  pickle.dump(preds, f)
 
1
  import argparse
2
+ import base64
3
+ import os
4
  import pickle
5
  import time
6
+ from typing import Dict, List
7
 
8
+ import cv2
9
  import numpy as np
10
  import requests
 
11
 
12
+ ENDPOINT = "http://127.0.0.1:8001"
13
+ if "REMOTE_URL_RAILWAY" in os.environ:
14
+ ENDPOINT = os.environ["REMOTE_URL_RAILWAY"]
15
 
16
+ print(f"API ENDPOINT: {ENDPOINT}")
17
 
18
+ API_VERSION = f"{ENDPOINT}/version"
19
+ API_URL_MATCH = f"{ENDPOINT}/v1/match"
20
+ API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
21
+
22
+
23
+ def read_image(path: str) -> str:
24
+ """
25
+ Read an image from a file as grayscale, encode it as a PNG and then as a base64 string.
26
+
27
+ Args:
28
+ path (str): The path to the image to read.
29
+
30
+ Returns:
31
+ str: The base64 encoded image.
32
+ """
33
+ # Read the image from the file
34
+ img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
35
+
36
+ # Encode the image as a PNG: lossless, so pixel values survive the round trip
37
+ retval, buffer = cv2.imencode(".png", img)
38
+
39
+ # Encode the PNG buffer as a base64 string
40
+ b64img = base64.b64encode(buffer).decode("utf-8")
41
+
42
+ return b64img
43
+
44
+
45
+ def do_api_requests(url=API_URL_EXTRACT, **kwargs):
46
+ """
47
+ Helper function to send an API request to the image matching service.
48
+
49
+ Args:
50
+ url (str): The URL of the API endpoint to use. Defaults to the
51
+ feature extraction endpoint.
52
+ **kwargs: Additional keyword arguments to pass to the API.
53
+
54
+ Returns:
55
+ List[Dict[str, np.ndarray]]: A list of dictionaries containing the
56
+ extracted features. The keys are "keypoints", "descriptors", and
57
+ "scores", and the values are ndarrays of shape (N, 2), (N, ?),
58
+ and (N,), respectively.
59
+ """
60
+ # Set up the request body
61
+ reqbody = {
62
+ # List of image data base64 encoded
63
+ "data": [],
64
+ # List of maximum number of keypoints to extract from each image
65
+ "max_keypoints": [100, 100],
66
+ # List of timestamps for each image (not used?)
67
+ "timestamps": ["0", "1"],
68
+ # Whether to convert the images to grayscale
69
+ "grayscale": 0,
70
+ # List of image height and width
71
+ "image_hw": [[640, 480], [320, 240]],
72
+ # Type of feature to extract
73
+ "feature_type": 0,
74
+ # List of rotation angles for each image
75
+ "rotates": [0.0, 0.0],
76
+ # List of scale factors for each image
77
+ "scales": [1.0, 1.0],
78
+ # List of reference points for each image (not used)
79
+ "reference_points": [[640, 480], [320, 240]],
80
+ # Whether to binarize the descriptors
81
+ "binarize": True,
82
+ }
83
+ # Update the request body with the additional keyword arguments
84
+ reqbody.update(kwargs)
85
+ try:
86
+ # Send the request
87
+ r = requests.post(url, json=reqbody)
88
+ if r.status_code == 200:
89
+ # Return the response
90
+ return r.json()
91
+ else:
92
+ # Print an error message if the response code is not 200
93
+ print(f"Error: Response code {r.status_code} - {r.text}")
94
+ except Exception as e:
95
+ # Print an error message if an exception occurs
96
+ print(f"An error occurred: {e}")
97
+
98
+
99
+ def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
100
  """
101
  Send a request to the API to generate a match between two images.
102
 
 
112
  """
113
  files = {"image0": open(path0, "rb"), "image1": open(path1, "rb")}
114
  try:
115
+ # TODO: replace files with post json
116
  response = requests.post(API_URL_MATCH, files=files)
117
  pred = {}
118
  if response.status_code == 200:
 
129
  return pred
130
 
131
 
132
+ def send_request_extract(
133
+ input_images: str, viz: bool = False
134
+ ) -> List[Dict[str, np.ndarray]]:
135
  """
136
  Send a request to the API to extract features from an image.
137
 
138
  Args:
139
+ input_images (str): The path to the image.
140
 
141
  Returns:
142
+ List[Dict[str, np.ndarray]]: A list of dictionaries containing the
143
+ extracted features. The keys are "keypoints", "descriptors", and
144
+ "scores", and the values are ndarrays of shape (N, 2), (N, 128),
145
+ and (N,), respectively.
146
  """
147
+ image_data = read_image(input_images)
148
+ inputs = {
149
+ "data": [image_data],
150
+ }
151
+ response = do_api_requests(
152
+ url=API_URL_EXTRACT,
153
+ **inputs,
154
+ )
155
+ print("Keypoints detected: {}".format(len(response[0]["keypoints"])))
156
+
157
+ # draw matching, debug only
158
+ if viz:
159
+ from hloc.utils.viz import plot_keypoints
160
+ from ui.viz import fig2im, plot_images
161
+
162
+ kpts = np.array(response[0]["keypoints_orig"])
163
+ if "image_orig" in response[0].keys():
164
+ img_orig = np.array(["image_orig"])
165
+
166
+ output_keypoints = plot_images([img_orig], titles="titles", dpi=300)
167
+ plot_keypoints([kpts])
168
+ output_keypoints = fig2im(output_keypoints)
169
+ cv2.imwrite(
170
+ "demo_match.jpg",
171
+ output_keypoints[:, :, ::-1].copy(), # RGB -> BGR
172
  )
173
+ return response
 
 
 
174
 
 
 
 
 
 
 
175
 
176
+ def get_api_version():
 
 
 
 
 
 
 
 
 
 
177
  try:
178
+ response = requests.get(API_VERSION).json()
179
+ print("API VERSION: {}".format(response["version"]))
 
 
 
 
 
 
 
 
180
  except Exception as e:
181
  print(f"An error occurred: {e}")
 
182
 
183
 
184
  if __name__ == "__main__":
 
189
  "--image0",
190
  required=False,
191
  help="Path for the file's melody",
192
+ default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg",
193
  )
194
  parser.add_argument(
195
  "--image1",
196
  required=False,
197
  help="Path for the file's melody",
198
+ default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
199
  )
200
  args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+ # get api version
203
+ get_api_version()
204
+
205
+ # request match
206
+ # for i in range(10):
207
+ # t1 = time.time()
208
+ # preds = send_request_match(args.image0, args.image1)
209
+ # t2 = time.time()
210
+ # print(
211
+ # "Time cost1: {} seconds, matched: {}".format(
212
+ # (t2 - t1), len(preds["mmkeypoints0_orig"])
213
+ # )
214
+ # )
215
+
216
+ # request extract
217
  for i in range(10):
218
  t1 = time.time()
219
+ preds = send_request_extract(args.image0)
220
  t2 = time.time()
221
+ print(f"Time cost2: {(t2 - t1)} seconds")
222
 
223
+ # dump preds
224
  with open("preds.pkl", "wb") as f:
225
  pickle.dump(preds, f)
api/server.py CHANGED
@@ -1,73 +1,435 @@
1
  # server.py
 
 
2
  import sys
 
3
  from pathlib import Path
4
- from typing import Union
5
 
 
 
6
  import numpy as np
 
7
  import uvicorn
8
  from fastapi import FastAPI, File, UploadFile
 
9
  from fastapi.responses import JSONResponse
10
  from PIL import Image
11
 
12
- sys.path.append("..")
13
- from pydantic import BaseModel
14
 
15
- from ui.api import ImageMatchingAPI
16
- from ui.utils import DEVICE
 
 
 
 
17
 
 
18
 
19
- class ImageInfo(BaseModel):
20
- image_path: str
21
- max_keypoints: int
22
- reference_points: list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  class ImageMatchingService:
26
  def __init__(self, conf: dict, device: str):
 
27
  self.api = ImageMatchingAPI(conf=conf, device=device)
28
  self.app = FastAPI()
29
  self.register_routes()
30
 
31
  def register_routes(self):
 
 
 
 
 
32
  @self.app.post("/v1/match")
33
  async def match(
34
  image0: UploadFile = File(...), image1: UploadFile = File(...)
35
  ):
 
 
 
 
 
 
 
 
 
 
 
36
  try:
 
37
  image0_array = self.load_image(image0)
38
  image1_array = self.load_image(image1)
39
 
 
40
  output = self.api(image0_array, image1_array)
41
 
 
42
  skip_keys = ["image0_orig", "image1_orig"]
43
- pred = self.filter_output(output, skip_keys)
44
 
 
 
 
 
45
  return JSONResponse(content=pred)
46
  except Exception as e:
 
47
  return JSONResponse(content={"error": str(e)}, status_code=500)
48
 
49
  @self.app.post("/v1/extract")
50
- async def extract(image: UploadFile = File(...)):
51
- try:
52
- image_array = self.load_image(image)
53
- output = self.api.extract(image_array)
54
- skip_keys = ["descriptors", "image", "image_orig"]
55
- pred = self.filter_output(output, skip_keys)
56
- return JSONResponse(content=pred)
57
- except Exception as e:
58
- return JSONResponse(content={"error": str(e)}, status_code=500)
59
 
60
- @self.app.post("/v2/extract")
61
- async def extract_v2(image_path: ImageInfo):
62
- img_path = image_path.image_path
63
  try:
64
- safe_path = Path(img_path).resolve(strict=False)
65
- image_array = self.load_image(str(safe_path))
66
- output = self.api.extract(image_array)
67
- skip_keys = ["descriptors", "image", "image_orig"]
68
- pred = self.filter_output(output, skip_keys)
69
- return JSONResponse(content=pred)
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  except Exception as e:
 
71
  return JSONResponse(content={"error": str(e)}, status_code=500)
72
 
73
  def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
@@ -88,7 +450,9 @@ class ImageMatchingService:
88
  image_array = np.array(img)
89
  return image_array
90
 
91
- def filter_output(self, output: dict, skip_keys: list) -> dict:
 
 
92
  pred = {}
93
  for key, value in output.items():
94
  if key in skip_keys:
 
1
  # server.py
2
+ import base64
3
+ import io
4
  import sys
5
+ import warnings
6
  from pathlib import Path
7
+ from typing import Any, Dict, Optional, Union
8
 
9
+ import cv2
10
+ import matplotlib.pyplot as plt
11
  import numpy as np
12
+ import torch
13
  import uvicorn
14
  from fastapi import FastAPI, File, UploadFile
15
+ from fastapi.exceptions import HTTPException
16
  from fastapi.responses import JSONResponse
17
  from PIL import Image
18
 
19
+ sys.path.append(str(Path(__file__).parents[1]))
 
20
 
21
+ from api.types import ImagesInput
22
+ from hloc import DEVICE, extract_features, logger, match_dense, match_features
23
+ from hloc.utils.viz import add_text, plot_keypoints
24
+ from ui import get_version
25
+ from ui.utils import filter_matches, get_feature_model, get_model
26
+ from ui.viz import display_matches, fig2im, plot_images
27
 
28
+ warnings.simplefilter("ignore")
29
 
30
+
31
+ def decode_base64_to_image(encoding):
32
+ if encoding.startswith("data:image/"):
33
+ encoding = encoding.split(";")[1].split(",")[1]
34
+ try:
35
+ image = Image.open(io.BytesIO(base64.b64decode(encoding)))
36
+ return image
37
+ except Exception as e:
38
+ logger.warning(f"API cannot decode image: {e}")
39
+ raise HTTPException(
40
+ status_code=500, detail="Invalid encoded image"
41
+ ) from e
42
+
43
+
44
+ def to_base64_nparray(encoding: str) -> np.ndarray:
45
+ return np.array(decode_base64_to_image(encoding)).astype("uint8")
46
+
47
+
48
+ class ImageMatchingAPI(torch.nn.Module):
49
+ default_conf = {
50
+ "ransac": {
51
+ "enable": True,
52
+ "estimator": "poselib",
53
+ "geometry": "homography",
54
+ "method": "RANSAC",
55
+ "reproj_threshold": 3,
56
+ "confidence": 0.9999,
57
+ "max_iter": 10000,
58
+ },
59
+ }
60
+
61
+ def __init__(
62
+ self,
63
+ conf: dict = {},
64
+ device: str = "cpu",
65
+ detect_threshold: float = 0.015,
66
+ max_keypoints: int = 1024,
67
+ match_threshold: float = 0.2,
68
+ ) -> None:
69
+ """
70
+ Initializes an instance of the ImageMatchingAPI class.
71
+
72
+ Args:
73
+ conf (dict): A dictionary containing the configuration parameters.
74
+ device (str, optional): The device to use for computation. Defaults to "cpu".
75
+ detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
76
+ max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
77
+ match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
78
+
79
+ Returns:
80
+ None
81
+ """
82
+ super().__init__()
83
+ self.device = device
84
+ self.conf = {**self.default_conf, **conf}
85
+ self._updata_config(detect_threshold, max_keypoints, match_threshold)
86
+ self._init_models()
87
+ if device == "cuda":
88
+ memory_allocated = torch.cuda.memory_allocated(device)
89
+ memory_reserved = torch.cuda.memory_reserved(device)
90
+ logger.info(
91
+ f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB"
92
+ )
93
+ logger.info(
94
+ f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB"
95
+ )
96
+ self.pred = None
97
+
98
+ def parse_match_config(self, conf):
99
+ if conf["dense"]:
100
+ return {
101
+ **conf,
102
+ "matcher": match_dense.confs.get(
103
+ conf["matcher"]["model"]["name"]
104
+ ),
105
+ "dense": True,
106
+ }
107
+ else:
108
+ return {
109
+ **conf,
110
+ "feature": extract_features.confs.get(
111
+ conf["feature"]["model"]["name"]
112
+ ),
113
+ "matcher": match_features.confs.get(
114
+ conf["matcher"]["model"]["name"]
115
+ ),
116
+ "dense": False,
117
+ }
118
+
119
+ def _updata_config(
120
+ self,
121
+ detect_threshold: float = 0.015,
122
+ max_keypoints: int = 1024,
123
+ match_threshold: float = 0.2,
124
+ ):
125
+ self.dense = self.conf["dense"]
126
+ if self.conf["dense"]:
127
+ try:
128
+ self.conf["matcher"]["model"][
129
+ "match_threshold"
130
+ ] = match_threshold
131
+ except TypeError as e:
132
+ logger.error(e)
133
+ else:
134
+ self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
135
+ self.conf["feature"]["model"][
136
+ "keypoint_threshold"
137
+ ] = detect_threshold
138
+ self.extract_conf = self.conf["feature"]
139
+
140
+ self.match_conf = self.conf["matcher"]
141
+
142
+ def _init_models(self):
143
+ # initialize matcher
144
+ self.matcher = get_model(self.match_conf)
145
+ # initialize extractor
146
+ if self.dense:
147
+ self.extractor = None
148
+ else:
149
+ self.extractor = get_feature_model(self.conf["feature"])
150
+
151
+ def _forward(self, img0, img1):
152
+ if self.dense:
153
+ pred = match_dense.match_images(
154
+ self.matcher,
155
+ img0,
156
+ img1,
157
+ self.match_conf["preprocessing"],
158
+ device=self.device,
159
+ )
160
+ last_fixed = "{}".format( # noqa: F841
161
+ self.match_conf["model"]["name"]
162
+ )
163
+ else:
164
+ pred0 = extract_features.extract(
165
+ self.extractor, img0, self.extract_conf["preprocessing"]
166
+ )
167
+ pred1 = extract_features.extract(
168
+ self.extractor, img1, self.extract_conf["preprocessing"]
169
+ )
170
+ pred = match_features.match_images(self.matcher, pred0, pred1)
171
+ return pred
172
+
173
+ @torch.inference_mode()
174
+ def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
175
+ """Extract features from a single image.
176
+
177
+ Args:
178
+ img0 (np.ndarray): image
179
+
180
+ Returns:
181
+ Dict[str, np.ndarray]: feature dict
182
+ """
183
+
184
+ # setting params
185
+ self.extractor.conf["max_keypoints"] = kwargs.get("max_keypoints", 512)
186
+ self.extractor.conf["keypoint_threshold"] = kwargs.get(
187
+ "keypoint_threshold", 0.0
188
+ )
189
+
190
+ pred = extract_features.extract(
191
+ self.extractor, img0, self.extract_conf["preprocessing"]
192
+ )
193
+ pred = {
194
+ k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
195
+ for k, v in pred.items()
196
+ }
197
+ # back to origin scale
198
+ s0 = pred["original_size"] / pred["size"]
199
+ pred["keypoints_orig"] = (
200
+ match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5
201
+ )
202
+ # TODO: rotate back
203
+
204
+ binarize = kwargs.get("binarize", False)
205
+ if binarize:
206
+ assert "descriptors" in pred
207
+ pred["descriptors"] = (pred["descriptors"] > 0).astype(np.uint8)
208
+ pred["descriptors"] = pred["descriptors"].T # N x DIM
209
+ return pred
210
+
211
+ @torch.inference_mode()
212
+ def forward(
213
+ self,
214
+ img0: np.ndarray,
215
+ img1: np.ndarray,
216
+ ) -> Dict[str, np.ndarray]:
217
+ """
218
+ Forward pass of the image matching API.
219
+
220
+ Args:
221
+ img0: A 3D NumPy array of shape (H, W, C) representing the first image.
222
+ Values are in the range [0, 1] and are in RGB mode.
223
+ img1: A 3D NumPy array of shape (H, W, C) representing the second image.
224
+ Values are in the range [0, 1] and are in RGB mode.
225
+
226
+ Returns:
227
+ A dictionary containing the following keys:
228
+ - image0_orig: The original image 0.
229
+ - image1_orig: The original image 1.
230
+ - keypoints0_orig: The keypoints detected in image 0.
231
+ - keypoints1_orig: The keypoints detected in image 1.
232
+ - mkeypoints0_orig: The raw matches between image 0 and image 1.
233
+ - mkeypoints1_orig: The raw matches between image 1 and image 0.
234
+ - mmkeypoints0_orig: The RANSAC inliers in image 0.
235
+ - mmkeypoints1_orig: The RANSAC inliers in image 1.
236
+ - mconf: The confidence scores for the raw matches.
237
+ - mmconf: The confidence scores for the RANSAC inliers.
238
+ """
239
+ # Take as input a pair of images (not a batch)
240
+ assert isinstance(img0, np.ndarray)
241
+ assert isinstance(img1, np.ndarray)
242
+ self.pred = self._forward(img0, img1)
243
+ if self.conf["ransac"]["enable"]:
244
+ self.pred = self._geometry_check(self.pred)
245
+ return self.pred
246
+
247
+ def _geometry_check(
248
+ self,
249
+ pred: Dict[str, Any],
250
+ ) -> Dict[str, Any]:
251
+ """
252
+ Filter matches using RANSAC. If keypoints are available, filter by keypoints.
253
+ If lines are available, filter by lines. If both keypoints and lines are
254
+ available, filter by keypoints.
255
+
256
+ Args:
257
+ pred (Dict[str, Any]): dict of matches, including original keypoints.
258
+ See :func:`filter_matches` for the expected keys.
259
+
260
+ Returns:
261
+ Dict[str, Any]: filtered matches
262
+ """
263
+ pred = filter_matches(
264
+ pred,
265
+ ransac_method=self.conf["ransac"]["method"],
266
+ ransac_reproj_threshold=self.conf["ransac"]["reproj_threshold"],
267
+ ransac_confidence=self.conf["ransac"]["confidence"],
268
+ ransac_max_iter=self.conf["ransac"]["max_iter"],
269
+ )
270
+ return pred
271
+
272
+ def visualize(
273
+ self,
274
+ log_path: Optional[Path] = None,
275
+ ) -> None:
276
+ """
277
+ Visualize the matches.
278
+
279
+ Args:
280
+ log_path (Path, optional): The directory to save the images. Defaults to None.
281
+
282
+ Returns:
283
+ None
284
+ """
285
+ if self.conf["dense"]:
286
+ postfix = str(self.conf["matcher"]["model"]["name"])
287
+ else:
288
+ postfix = "{}_{}".format(
289
+ str(self.conf["feature"]["model"]["name"]),
290
+ str(self.conf["matcher"]["model"]["name"]),
291
+ )
292
+ titles = [
293
+ "Image 0 - Keypoints",
294
+ "Image 1 - Keypoints",
295
+ ]
296
+ pred: Dict[str, Any] = self.pred
297
+ image0: np.ndarray = pred["image0_orig"]
298
+ image1: np.ndarray = pred["image1_orig"]
299
+ output_keypoints: np.ndarray = plot_images(
300
+ [image0, image1], titles=titles, dpi=300
301
+ )
302
+ if (
303
+ "keypoints0_orig" in pred.keys()
304
+ and "keypoints1_orig" in pred.keys()
305
+ ):
306
+ plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
307
+ text: str = (
308
+ f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
309
+ + f"# keypoints1: {len(pred['keypoints1_orig'])}"
310
+ )
311
+ add_text(0, text, fs=15)
312
+ output_keypoints = fig2im(output_keypoints)
313
+ # plot images with raw matches
314
+ titles = [
315
+ "Image 0 - Raw matched keypoints",
316
+ "Image 1 - Raw matched keypoints",
317
+ ]
318
+ output_matches_raw, num_matches_raw = display_matches(
319
+ pred, titles=titles, tag="KPTS_RAW"
320
+ )
321
+ # plot images with ransac matches
322
+ titles = [
323
+ "Image 0 - Ransac matched keypoints",
324
+ "Image 1 - Ransac matched keypoints",
325
+ ]
326
+ output_matches_ransac, num_matches_ransac = display_matches(
327
+ pred, titles=titles, tag="KPTS_RANSAC"
328
+ )
329
+ if log_path is not None:
330
+ img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png"
331
+ img_matches_raw_path: Path = (
332
+ log_path / f"img_matches_raw_{postfix}.png"
333
+ )
334
+ img_matches_ransac_path: Path = (
335
+ log_path / f"img_matches_ransac_{postfix}.png"
336
+ )
337
+ cv2.imwrite(
338
+ str(img_keypoints_path),
339
+ output_keypoints[:, :, ::-1].copy(), # RGB -> BGR
340
+ )
341
+ cv2.imwrite(
342
+ str(img_matches_raw_path),
343
+ output_matches_raw[:, :, ::-1].copy(), # RGB -> BGR
344
+ )
345
+ cv2.imwrite(
346
+ str(img_matches_ransac_path),
347
+ output_matches_ransac[:, :, ::-1].copy(), # RGB -> BGR
348
+ )
349
+ plt.close("all")
350
 
351
 
352
  class ImageMatchingService:
353
  def __init__(self, conf: dict, device: str):
354
+ self.conf = conf
355
  self.api = ImageMatchingAPI(conf=conf, device=device)
356
  self.app = FastAPI()
357
  self.register_routes()
358
 
359
  def register_routes(self):
360
+
361
+ @self.app.get("/version")
362
+ async def version():
363
+ return {"version": get_version()}
364
+
365
  @self.app.post("/v1/match")
366
  async def match(
367
  image0: UploadFile = File(...), image1: UploadFile = File(...)
368
  ):
369
+ """
370
+ Handle the image matching request and return the processed result.
371
+
372
+ Args:
373
+ image0 (UploadFile): The first image file for matching.
374
+ image1 (UploadFile): The second image file for matching.
375
+
376
+ Returns:
377
+ JSONResponse: A JSON response containing the filtered match results
378
+ or an error message in case of failure.
379
+ """
380
  try:
381
+ # Load the images from the uploaded files
382
  image0_array = self.load_image(image0)
383
  image1_array = self.load_image(image1)
384
 
385
+ # Perform image matching using the API
386
  output = self.api(image0_array, image1_array)
387
 
388
+ # Keys to skip in the output
389
  skip_keys = ["image0_orig", "image1_orig"]
 
390
 
391
+ # Postprocess the output to filter unwanted data
392
+ pred = self.postprocess(output, skip_keys)
393
+
394
+ # Return the filtered prediction as a JSON response
395
  return JSONResponse(content=pred)
396
  except Exception as e:
397
+ # Return an error message with status code 500 in case of exception
398
  return JSONResponse(content={"error": str(e)}, status_code=500)
399
 
400
  @self.app.post("/v1/extract")
401
+ async def extract(input_info: ImagesInput):
402
+ """
403
+ Extract keypoints and descriptors from images.
404
+
405
+ Args:
406
+ input_info: An object containing the image data and options.
 
 
 
407
 
408
+ Returns:
409
+ A list of dictionaries containing the keypoints and descriptors.
410
+ """
411
  try:
412
+ preds = []
413
+ for i, input_image in enumerate(input_info.data):
414
+ # Load the image from the input data
415
+ image_array = to_base64_nparray(input_image)
416
+ # Extract keypoints and descriptors
417
+ output = self.api.extract(
418
+ image_array,
419
+ max_keypoints=input_info.max_keypoints[i],
420
+ binarize=input_info.binarize,
421
+ )
422
+ # Do not return the original image and image_orig
423
+ # skip_keys = ["image", "image_orig"]
424
+ skip_keys = []
425
+
426
+ # Postprocess the output
427
+ pred = self.postprocess(output, skip_keys)
428
+ preds.append(pred)
429
+ # Return the list of extracted features
430
+ return JSONResponse(content=preds)
431
  except Exception as e:
432
+ # Return an error message if an exception occurs
433
  return JSONResponse(content={"error": str(e)}, status_code=500)
434
 
435
  def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
 
450
  image_array = np.array(img)
451
  return image_array
452
 
453
+ def postprocess(
454
+ self, output: dict, skip_keys: list, binarize: bool = True
455
+ ) -> dict:
456
  pred = {}
457
  for key, value in output.items():
458
  if key in skip_keys:
api/test/CMakeLists.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cmake_minimum_required(VERSION 3.10)
2
+ project(imatchui)
3
+
4
+ set(OpenCV_DIR /usr/include/opencv4)
5
+ find_package(OpenCV REQUIRED)
6
+
7
+ find_package(Boost REQUIRED COMPONENTS system)
8
+ if(Boost_FOUND)
9
+ include_directories(${Boost_INCLUDE_DIRS})
10
+ endif()
11
+
12
+ add_executable(client client.cpp)
13
+
14
+ target_include_directories(client PRIVATE ${Boost_LIBRARIES} ${OpenCV_INCLUDE_DIRS})
15
+
16
+ target_link_libraries(client PRIVATE curl jsoncpp b64 ${OpenCV_LIBS})
api/test/build_and_run.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # g++ main.cpp -I/usr/include/opencv4 -lcurl -ljsoncpp -lb64 -lopencv_core -lopencv_imgcodecs -o main
2
+ # sudo apt-get update
3
+ # sudo apt-get install libboost-all-dev -y
4
+ # sudo apt-get install libcurl4-openssl-dev libjsoncpp-dev libb64-dev libopencv-dev -y
5
+
6
+ cd build
7
+ cmake ..
8
+ make -j12
9
+
10
+ echo " ======== RUN DEMO ========"
11
+
12
+ ./client
13
+
14
+ echo " ======== END DEMO ========"
15
+
16
+ cd ..
api/test/client.cpp ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <curl/curl.h>
2
+ #include <opencv2/opencv.hpp>
3
+ #include "helper.h"
4
+
5
+ int main() {
6
+ std::string img_path = "../../../datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg";
7
+ cv::Mat original_img = cv::imread(img_path, cv::IMREAD_GRAYSCALE);
8
+
9
+ if (original_img.empty()) {
10
+ throw std::runtime_error("Failed to decode image");
11
+ }
12
+
13
+ // Convert the image to Base64
14
+ std::string base64_img = image_to_base64(original_img);
15
+
16
+ // Convert the Base64 back to an image
17
+ cv::Mat decoded_img = base64_to_image(base64_img);
18
+ cv::imwrite("decoded_image.jpg", decoded_img);
19
+ cv::imwrite("original_img.jpg", original_img);
20
+
21
+ // The images should be identical
22
+ if (cv::countNonZero(original_img != decoded_img) != 0) {
23
+ std::cerr << "The images are not identical" << std::endl;
24
+ return -1;
25
+ } else {
26
+ std::cout << "The images are identical!" << std::endl;
27
+ }
28
+
29
+ // construct params
30
+ APIParams params{
31
+ .data = {base64_img},
32
+ .max_keypoints = {100, 100},
33
+ .timestamps = {"0", "1"},
34
+ .grayscale = {0},
35
+ .image_hw = {{480, 640}, {240, 320}},
36
+ .feature_type = 0,
37
+ .rotates = {0.0f, 0.0f},
38
+ .scales = {1.0f, 1.0f},
39
+ .reference_points = {
40
+ {1.23e+2f, 1.2e+1f},
41
+ {5.0e-1f, 3.0e-1f},
42
+ {2.3e+2f, 2.2e+1f},
43
+ {6.0e-1f, 4.0e-1f}
44
+ },
45
+ .binarize = {1}
46
+ };
47
+
48
+ KeyPointResults kpts_results;
49
+
50
+ // Convert the parameters to JSON
51
+ Json::Value jsonData = paramsToJson(params);
52
+ std::string url = "http://127.0.0.1:8001/v1/extract";
53
+ Json::StreamWriterBuilder writer;
54
+ std::string output = Json::writeString(writer, jsonData);
55
+
56
+ CURL* curl;
57
+ CURLcode res;
58
+ std::string readBuffer;
59
+
60
+ curl_global_init(CURL_GLOBAL_DEFAULT);
61
+ curl = curl_easy_init();
62
+ if (curl) {
63
+ struct curl_slist* hs = NULL;
64
+ hs = curl_slist_append(hs, "Content-Type: application/json");
65
+ curl_easy_setopt(curl, CURLOPT_HTTPHEADER, hs);
66
+ curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
67
+ curl_easy_setopt(curl, CURLOPT_POSTFIELDS, output.c_str());
68
+ curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
69
+ curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer);
70
+ res = curl_easy_perform(curl);
71
+
72
+ if (res != CURLE_OK)
73
+ fprintf(stderr, "curl_easy_perform() failed: %s\n",
74
+ curl_easy_strerror(res));
75
+ else {
76
+ // std::cout << "Response from server: " << readBuffer << std::endl;
77
+ kpts_results = decode_response(readBuffer);
78
+ }
79
+ curl_easy_cleanup(curl);
80
+ }
81
+ curl_global_cleanup();
82
+
83
+ return 0;
84
+ }
api/test/helper.h ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #include <sstream>
3
+ #include <fstream>
4
+ #include <vector>
5
+ #include <b64/encode.h>
6
+ #include <jsoncpp/json/json.h>
7
+ #include <opencv2/opencv.hpp>
8
+
9
+ // base64 to image
10
+ #include <boost/archive/iterators/binary_from_base64.hpp>
11
+ #include <boost/archive/iterators/transform_width.hpp>
12
+ #include <boost/archive/iterators/base64_from_binary.hpp>
13
+
14
/**
 * @brief Request parameters for the feature-extraction API.
 *
 * paramsToJson() serializes every member into a JSON field of the same name.
 * The scalar members carry default initializers so that a default-constructed
 * APIParams never holds indeterminate values; the struct remains an aggregate,
 * so brace and designated initialization keep working.
 */
struct APIParams {
    /// A list of images, base64 encoded
    std::vector<std::string> data;

    /// The maximum number of keypoints to detect for each image
    std::vector<int> max_keypoints;

    /// The timestamps of the images
    std::vector<std::string> timestamps;

    /// Whether to convert the images to grayscale
    bool grayscale = false;

    /// The height and width of each image
    std::vector<std::vector<int>> image_hw;

    /// The type of feature detector to use
    int feature_type = 0;

    /// The rotations of the images
    std::vector<double> rotates;

    /// The scales of the images
    std::vector<double> scales;

    /// The reference points of the images
    std::vector<std::vector<float>> reference_points;

    /// Whether to binarize the descriptors
    bool binarize = false;
};
46
+
47
+ /**
48
+ * @brief Contains the results of a keypoint detector.
49
+ *
50
+ * @details Stores the keypoints and descriptors for each image.
51
+ */
52
+ class KeyPointResults {
53
+ public:
54
+ KeyPointResults() {}
55
+
56
+ /**
57
+ * @brief Constructor.
58
+ *
59
+ * @param kp The keypoints for each image.
60
+ */
61
+ KeyPointResults(const std::vector<std::vector<cv::KeyPoint>>& kp,
62
+ const std::vector<cv::Mat>& desc)
63
+ : keypoints(kp), descriptors(desc) {}
64
+
65
+ /**
66
+ * @brief Append keypoints to the result.
67
+ *
68
+ * @param kpts The keypoints to append.
69
+ */
70
+ inline void append_keypoints(std::vector<cv::KeyPoint>& kpts) {
71
+ keypoints.emplace_back(kpts);
72
+ }
73
+
74
+ /**
75
+ * @brief Append descriptors to the result.
76
+ *
77
+ * @param desc The descriptors to append.
78
+ */
79
+ inline void append_descriptors(cv::Mat& desc) {
80
+ descriptors.emplace_back(desc);
81
+ }
82
+
83
+ /**
84
+ * @brief Get the keypoints.
85
+ *
86
+ * @return The keypoints.
87
+ */
88
+ inline std::vector<std::vector<cv::KeyPoint>> get_keypoints() {
89
+ return keypoints;
90
+ }
91
+
92
+ /**
93
+ * @brief Get the descriptors.
94
+ *
95
+ * @return The descriptors.
96
+ */
97
+ inline std::vector<cv::Mat> get_descriptors() {
98
+ return descriptors;
99
+ }
100
+
101
+ private:
102
+ std::vector<std::vector<cv::KeyPoint>> keypoints;
103
+ std::vector<cv::Mat> descriptors;
104
+ std::vector<std::vector<float>> scores;
105
+ };
106
+
107
+
108
+ /**
109
+ * @brief Decodes a base64 encoded string.
110
+ *
111
+ * @param base64 The base64 encoded string to decode.
112
+ * @return The decoded string.
113
+ */
114
+ std::string base64_decode(const std::string& base64) {
115
+ using namespace boost::archive::iterators;
116
+ using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>;
117
+
118
+ // Find the position of the last non-whitespace character
119
+ auto end = base64.find_last_not_of(" \t\n\r");
120
+ if (end != std::string::npos) {
121
+ // Move one past the last non-whitespace character
122
+ end += 1;
123
+ }
124
+
125
+ // Decode the base64 string and return the result
126
+ return std::string(It(base64.begin()), It(base64.begin() + end));
127
+ }
128
+
129
+
130
+
131
+ /**
132
+ * @brief Decodes a base64 string into an OpenCV image
133
+ *
134
+ * @param base64 The base64 encoded string
135
+ * @return The decoded OpenCV image
136
+ */
137
+ cv::Mat base64_to_image(const std::string& base64) {
138
+ // Decode the base64 string
139
+ std::string decodedStr = base64_decode(base64);
140
+
141
+ // Decode the image
142
+ std::vector<uchar> data(decodedStr.begin(), decodedStr.end());
143
+ cv::Mat img = cv::imdecode(data, cv::IMREAD_GRAYSCALE);
144
+
145
+ // Check for errors
146
+ if (img.empty()) {
147
+ throw std::runtime_error("Failed to decode image");
148
+ }
149
+
150
+ return img;
151
+ }
152
+
153
+
154
+ /**
155
+ * @brief Encodes an OpenCV image into a base64 string
156
+ *
157
+ * This function takes an OpenCV image and encodes it into a base64 string.
158
+ * The image is first encoded as a PNG image, and then the resulting
159
+ * bytes are encoded as a base64 string.
160
+ *
161
+ * @param img The OpenCV image
162
+ * @return The base64 encoded string
163
+ *
164
+ * @throws std::runtime_error if the image is empty or encoding fails
165
+ */
166
+ std::string image_to_base64(cv::Mat &img) {
167
+ if (img.empty()) {
168
+ throw std::runtime_error("Failed to read image");
169
+ }
170
+
171
+ // Encode the image as a PNG
172
+ std::vector<uchar> buf;
173
+ if (!cv::imencode(".png", img, buf)) {
174
+ throw std::runtime_error("Failed to encode image");
175
+ }
176
+
177
+ // Encode the bytes as a base64 string
178
+ using namespace boost::archive::iterators;
179
+ using It = base64_from_binary<transform_width<std::vector<uchar>::const_iterator, 6, 8>>;
180
+ std::string base64(It(buf.begin()), It(buf.end()));
181
+
182
+ // Pad the string with '=' characters to a multiple of 4 bytes
183
+ base64.append((3 - buf.size() % 3) % 3, '=');
184
+
185
+ return base64;
186
+ }
187
+
188
+
189
/**
 * @brief Callback function for libcurl to write data to a string
 *
 * Appends the received chunk to the string registered as the write target.
 * Declared inline because the definition lives in a header; without it,
 * including the header from two translation units produces duplicate
 * definitions at link time.
 *
 * @param contents The data to write
 * @param size The size of one data member
 * @param nmemb The number of members in the data
 * @param s The string to append the data to
 * @return The number of bytes written, or 0 if memory could not be allocated
 */
inline size_t WriteCallback(void* contents, size_t size, size_t nmemb, std::string* s) {
    const size_t newLength = size * nmemb;
    try {
        // Grow the string and copy the chunk in a single step
        s->append(static_cast<const char*>(contents), newLength);
    } catch (const std::bad_alloc&) {
        // If there's an error allocating memory, return 0
        return 0;
    }
    return newLength;
}
218
+
219
+ // Helper functions
220
+
221
+ /**
222
+ * @brief Helper function to convert a type to a Json::Value
223
+ *
224
+ * This function takes a value of type T and converts it to a Json::Value.
225
+ * It is used to simplify the process of converting a type to a Json::Value.
226
+ *
227
+ * @param val The value to convert
228
+ * @return The converted Json::Value
229
+ */
230
+ template <typename T>
231
+ Json::Value toJson(const T& val) {
232
+ return Json::Value(val);
233
+ }
234
+
235
+ /**
236
+ * @brief Converts a vector to a Json::Value
237
+ *
238
+ * This function takes a vector of type T and converts it to a Json::Value.
239
+ * Each element in the vector is appended to the Json::Value array.
240
+ *
241
+ * @param vec The vector to convert to Json::Value
242
+ * @return The Json::Value representing the vector
243
+ */
244
+ template <typename T>
245
+ Json::Value vectorToJson(const std::vector<T>& vec) {
246
+ Json::Value json(Json::arrayValue);
247
+ for (const auto& item : vec) {
248
+ json.append(item);
249
+ }
250
+ return json;
251
+ }
252
+
253
+ /**
254
+ * @brief Converts a nested vector to a Json::Value
255
+ *
256
+ * This function takes a nested vector of type T and converts it to a Json::Value.
257
+ * Each sub-vector is converted to a Json::Value array and appended to the main Json::Value array.
258
+ *
259
+ * @param vec The nested vector to convert to Json::Value
260
+ * @return The Json::Value representing the nested vector
261
+ */
262
+ template <typename T>
263
+ Json::Value nestedVectorToJson(const std::vector<std::vector<T>>& vec) {
264
+ Json::Value json(Json::arrayValue);
265
+ for (const auto& subVec : vec) {
266
+ json.append(vectorToJson(subVec));
267
+ }
268
+ return json;
269
+ }
270
+
271
+
272
+
273
+ /**
274
+ * @brief Converts the APIParams struct to a Json::Value
275
+ *
276
+ * This function takes an APIParams struct and converts it to a Json::Value.
277
+ * The Json::Value is a JSON object with the following fields:
278
+ * - data: a JSON array of base64 encoded images
279
+ * - max_keypoints: a JSON array of integers, max number of keypoints for each image
280
+ * - timestamps: a JSON array of timestamps, one for each image
281
+ * - grayscale: a JSON boolean, whether to convert images to grayscale
282
+ * - image_hw: a nested JSON array, each sub-array contains the height and width of an image
283
+ * - feature_type: a JSON integer, the type of feature detector to use
284
+ * - rotates: a JSON array of doubles, the rotation of each image
285
+ * - scales: a JSON array of doubles, the scale of each image
286
+ * - reference_points: a nested JSON array, each sub-array contains the reference points of an image
287
+ * - binarize: a JSON boolean, whether to binarize the descriptors
288
+ *
289
+ * @param params The APIParams struct to convert
290
+ * @return The Json::Value representing the APIParams struct
291
+ */
292
+ Json::Value paramsToJson(const APIParams& params) {
293
+ Json::Value json;
294
+ json["data"] = vectorToJson(params.data);
295
+ json["max_keypoints"] = vectorToJson(params.max_keypoints);
296
+ json["timestamps"] = vectorToJson(params.timestamps);
297
+ json["grayscale"] = toJson(params.grayscale);
298
+ json["image_hw"] = nestedVectorToJson(params.image_hw);
299
+ json["feature_type"] = toJson(params.feature_type);
300
+ json["rotates"] = vectorToJson(params.rotates);
301
+ json["scales"] = vectorToJson(params.scales);
302
+ json["reference_points"] = nestedVectorToJson(params.reference_points);
303
+ json["binarize"] = toJson(params.binarize);
304
+ return json;
305
+ }
306
+
307
+ template<typename T>
308
+ cv::Mat jsonToMat(Json::Value json) {
309
+ int rows = json.size();
310
+ int cols = json[0].size();
311
+
312
+ // Create a single array to hold all the data.
313
+ std::vector<T> data;
314
+ data.reserve(rows * cols);
315
+
316
+ for (int i = 0; i < rows; i++) {
317
+ for (int j = 0; j < cols; j++) {
318
+ data.push_back(static_cast<T>(json[i][j].asInt()));
319
+ }
320
+ }
321
+
322
+ // Create a cv::Mat object that points to the data.
323
+ cv::Mat mat(rows, cols, CV_8UC1, data.data()); // Change the type if necessary.
324
+ // cv::Mat mat(cols, rows,CV_8UC1, data.data()); // Change the type if necessary.
325
+
326
+ return mat;
327
+ }
328
+
329
+
330
+
331
+ /**
332
+ * @brief Decodes the response of the server and prints the keypoints
333
+ *
334
+ * This function takes the response of the server, a JSON string, and decodes
335
+ * it. It then prints the keypoints and draws them on the original image.
336
+ *
337
+ * @param response The response of the server
338
+ * @return The keypoints and descriptors
339
+ */
340
+ KeyPointResults decode_response(const std::string& response, bool viz=true) {
341
+ Json::CharReaderBuilder builder;
342
+ Json::CharReader* reader = builder.newCharReader();
343
+
344
+ Json::Value jsonData;
345
+ std::string errors;
346
+
347
+ // Parse the JSON response
348
+ bool parsingSuccessful = reader->parse(response.c_str(),
349
+ response.c_str() + response.size(), &jsonData, &errors);
350
+ delete reader;
351
+
352
+ if (!parsingSuccessful) {
353
+ // Handle error
354
+ std::cout << "Failed to parse the JSON, errors:" << std::endl;
355
+ std::cout << errors << std::endl;
356
+ return KeyPointResults();
357
+ }
358
+
359
+ KeyPointResults kpts_results;
360
+
361
+ // Iterate over the images
362
+ for (const auto& jsonItem : jsonData) {
363
+ auto jkeypoints = jsonItem["keypoints"];
364
+ auto jkeypoints_orig = jsonItem["keypoints_orig"];
365
+ auto jdescriptors = jsonItem["descriptors"];
366
+ auto jscores = jsonItem["scores"];
367
+ auto jimageSize = jsonItem["image_size"];
368
+ auto joriginalSize = jsonItem["original_size"];
369
+ auto jsize = jsonItem["size"];
370
+
371
+ std::vector<cv::KeyPoint> vkeypoints;
372
+ std::vector<float> vscores;
373
+
374
+ // Iterate over the keypoints
375
+ int counter = 0;
376
+ for (const auto& keypoint : jkeypoints_orig) {
377
+ if (counter < 10) {
378
+ // Print the first 10 keypoints
379
+ std::cout << keypoint[0].asFloat() << ", "
380
+ << keypoint[1].asFloat() << std::endl;
381
+ }
382
+ counter++;
383
+ // Convert the Json::Value to a cv::KeyPoint
384
+ vkeypoints.emplace_back(cv::KeyPoint(keypoint[0].asFloat(),
385
+ keypoint[1].asFloat(), 0.0));
386
+ }
387
+
388
+ if (viz && jsonItem.isMember("image_orig")) {
389
+
390
+ auto jimg_orig = jsonItem["image_orig"];
391
+ cv::Mat img = jsonToMat<uchar>(jimg_orig);
392
+ cv::imwrite("viz_image_orig.jpg", img);
393
+
394
+ // Draw keypoints on the image
395
+ cv::Mat imgWithKeypoints;
396
+ cv::drawKeypoints(img, vkeypoints,
397
+ imgWithKeypoints, cv::Scalar(0, 0, 255));
398
+
399
+ // Write the image with keypoints
400
+ std::string filename = "viz_image_orig_keypoints.jpg";
401
+ cv::imwrite(filename, imgWithKeypoints);
402
+ }
403
+
404
+ // Iterate over the descriptors
405
+ cv::Mat descriptors = jsonToMat<uchar>(jdescriptors);
406
+ kpts_results.append_keypoints(vkeypoints);
407
+ kpts_results.append_descriptors(descriptors);
408
+ }
409
+ return kpts_results;
410
+ }
api/types.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class ImagesInput(BaseModel):
    """Request body describing a batch of images for feature extraction.

    The field names match the JSON object that the C++ test client builds in
    ``paramsToJson`` (api/test/helper.h); the per-field notes below follow the
    comments on its ``APIParams`` struct.
    """

    # Images, base64 encoded.
    data: List[str] = []
    # Maximum number of keypoints to detect for each image.
    max_keypoints: List[int] = []
    # Timestamps of the images.
    timestamps: List[str] = []
    # Whether to convert the images to grayscale.
    grayscale: bool = False
    # Height and width of each image.
    # NOTE(review): the default holds two empty entries -- presumably one per
    # image of a pair; confirm against the server code.
    image_hw: List[List[int]] = [[], []]
    # Type of feature detector to use.
    feature_type: int = 0
    # Rotations of the images.
    rotates: List[float] = []
    # Scales of the images.
    scales: List[float] = []
    # Reference points of the images.
    reference_points: List[List[float]] = []
    # Whether to binarize the descriptors.
    binarize: bool = False
requirements.txt CHANGED
@@ -2,8 +2,7 @@ e2cnn
2
  einops
3
  easydict
4
  gdown
5
- gradio==4.44.0
6
- gradio_client==1.3.0
7
  h5py
8
  huggingface_hub
9
  imageio
 
2
  einops
3
  easydict
4
  gdown
5
+ gradio==5.1.0
 
6
  h5py
7
  huggingface_hub
8
  imageio
test_app_cli.py CHANGED
@@ -1,12 +1,13 @@
 
 
 
1
  import cv2
 
2
  from hloc import logger
3
- from ui.utils import (
4
- get_matcher_zoo,
5
- load_config,
6
- DEVICE,
7
- ROOT,
8
- )
9
- from ui.api import ImageMatchingAPI
10
 
11
 
12
  def test_all(config: dict = None):
@@ -68,7 +69,7 @@ def test_one():
68
  "dense": False,
69
  }
70
  api = ImageMatchingAPI(conf=conf, device=DEVICE)
71
- pred = api(image0, image1)
72
  log_path = ROOT / "experiments" / "one"
73
  log_path.mkdir(exist_ok=True, parents=True)
74
  api.visualize(log_path=log_path)
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
  import cv2
5
+
6
  from hloc import logger
7
+ from ui.utils import DEVICE, ROOT, get_matcher_zoo, load_config
8
+
9
+ sys.path.append(str(Path(__file__).parents[1]))
10
+ from api.server import ImageMatchingAPI
 
 
 
11
 
12
 
13
  def test_all(config: dict = None):
 
69
  "dense": False,
70
  }
71
  api = ImageMatchingAPI(conf=conf, device=DEVICE)
72
+ api(image0, image1)
73
  log_path = ROOT / "experiments" / "one"
74
  log_path.mkdir(exist_ok=True, parents=True)
75
  api.visualize(log_path=log_path)
ui/__init__.py CHANGED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Version of the ui package; get_version() returns this value.
__version__ = "1.0.1"


def get_version() -> str:
    """Return the version string of the ui package."""
    version = __version__
    return version
ui/api.py DELETED
@@ -1,293 +0,0 @@
1
- import warnings
2
- from pathlib import Path
3
- from typing import Any, Dict, Optional
4
-
5
- import cv2
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
- import torch
9
-
10
- from hloc import extract_features, logger, match_dense, match_features
11
- from hloc.utils.viz import add_text, plot_keypoints
12
-
13
- from .utils import (
14
- ROOT,
15
- filter_matches,
16
- get_feature_model,
17
- get_model,
18
- load_config,
19
- )
20
- from .viz import display_matches, fig2im, plot_images
21
-
22
- warnings.simplefilter("ignore")
23
-
24
-
25
class ImageMatchingAPI(torch.nn.Module):
    """Match a pair of images with a configurable extractor/matcher pipeline.

    The configuration selects either a dense matcher, which works on the
    images directly, or a sparse pipeline that extracts features and then
    matches them. Matches can optionally be filtered with RANSAC, and the
    latest prediction is kept in ``self.pred`` for :meth:`visualize`.
    """

    default_conf = {
        "ransac": {
            "enable": True,
            "estimator": "poselib",
            "geometry": "homography",
            "method": "RANSAC",
            "reproj_threshold": 3,
            "confidence": 0.9999,
            "max_iter": 10000,
        },
    }

    def __init__(
        self,
        conf: Optional[dict] = None,
        device: str = "cpu",
        detect_threshold: float = 0.015,
        max_keypoints: int = 1024,
        match_threshold: float = 0.2,
    ) -> None:
        """
        Initializes an instance of the ImageMatchingAPI class.

        Args:
            conf (dict, optional): A dictionary containing the configuration
                parameters. It is merged over ``default_conf`` and must
                provide the "dense" and "matcher" entries (plus "feature" for
                sparse pipelines).
            device (str, optional): The device to use for computation. Defaults to "cpu".
            detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
            max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
            match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.

        Returns:
            None
        """
        import copy

        super().__init__()
        self.device = device
        # Deep-copy the merged configuration: _updata_config writes the
        # thresholds into nested dictionaries, and a shallow merge would leak
        # those writes into the caller's dict and into default_conf.
        self.conf = copy.deepcopy({**self.default_conf, **(conf or {})})
        self._updata_config(detect_threshold, max_keypoints, match_threshold)
        self._init_models()
        if device == "cuda":
            memory_allocated = torch.cuda.memory_allocated(device)
            memory_reserved = torch.cuda.memory_reserved(device)
            logger.info(
                f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB"
            )
            logger.info(
                f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB"
            )
        self.pred = None

    def parse_match_config(self, conf):
        """Replace model names in ``conf`` with the matching hloc configs.

        Args:
            conf: Configuration with a "dense" flag and the model names under
                ``conf["matcher"]["model"]["name"]`` (and
                ``conf["feature"]["model"]["name"]`` for sparse pipelines).

        Returns:
            A new dictionary in which "matcher" (and "feature") hold the
            configurations looked up in hloc's ``confs`` tables.
        """
        if conf["dense"]:
            return {
                **conf,
                "matcher": match_dense.confs.get(
                    conf["matcher"]["model"]["name"]
                ),
                "dense": True,
            }
        return {
            **conf,
            "feature": extract_features.confs.get(
                conf["feature"]["model"]["name"]
            ),
            "matcher": match_features.confs.get(
                conf["matcher"]["model"]["name"]
            ),
            "dense": False,
        }

    def _updata_config(
        self,
        detect_threshold: float = 0.015,
        max_keypoints: int = 1024,
        match_threshold: float = 0.2,
    ):
        """Write the runtime thresholds into ``self.conf`` and cache sub-configs."""
        self.dense = self.conf["dense"]
        if self.dense:
            try:
                self.conf["matcher"]["model"][
                    "match_threshold"
                ] = match_threshold
            except TypeError as e:
                logger.error(e)
            # Dense matchers have no separate extractor; define the attribute
            # anyway so that reading it never raises AttributeError.
            self.extract_conf = None
        else:
            self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
            self.conf["feature"]["model"][
                "keypoint_threshold"
            ] = detect_threshold
            self.extract_conf = self.conf["feature"]

        self.match_conf = self.conf["matcher"]

    def _init_models(self):
        """Instantiate the matcher and, for sparse pipelines, the extractor."""
        # initialize matcher
        self.matcher = get_model(self.match_conf)
        # initialize extractor
        if self.dense:
            self.extractor = None
        else:
            self.extractor = get_feature_model(self.conf["feature"])

    def _forward(self, img0, img1):
        """Run the dense matcher, or extraction followed by sparse matching."""
        if self.dense:
            return match_dense.match_images(
                self.matcher,
                img0,
                img1,
                self.match_conf["preprocessing"],
                device=self.device,
            )
        pred0 = extract_features.extract(
            self.extractor, img0, self.extract_conf["preprocessing"]
        )
        pred1 = extract_features.extract(
            self.extractor, img1, self.extract_conf["preprocessing"]
        )
        return match_features.match_images(self.matcher, pred0, pred1)

    @torch.inference_mode()
    def forward(
        self,
        img0: np.ndarray,
        img1: np.ndarray,
    ) -> Dict[str, np.ndarray]:
        """
        Forward pass of the image matching API.

        Args:
            img0: A 3D NumPy array of shape (H, W, C) representing the first image.
                Values are in the range [0, 1] and are in RGB mode.
            img1: A 3D NumPy array of shape (H, W, C) representing the second image.
                Values are in the range [0, 1] and are in RGB mode.

        Returns:
            A dictionary containing the following keys:
            - image0_orig: The original image 0.
            - image1_orig: The original image 1.
            - keypoints0_orig: The keypoints detected in image 0.
            - keypoints1_orig: The keypoints detected in image 1.
            - mkeypoints0_orig: The raw matches between image 0 and image 1.
            - mkeypoints1_orig: The raw matches between image 1 and image 0.
            - mmkeypoints0_orig: The RANSAC inliers in image 0.
            - mmkeypoints1_orig: The RANSAC inliers in image 1.
            - mconf: The confidence scores for the raw matches.
            - mmconf: The confidence scores for the RANSAC inliers.
        """
        # Take as input a pair of images (not a batch)
        assert isinstance(img0, np.ndarray)
        assert isinstance(img1, np.ndarray)
        self.pred = self._forward(img0, img1)
        if self.conf["ransac"]["enable"]:
            self.pred = self._geometry_check(self.pred)
        return self.pred

    def _geometry_check(
        self,
        pred: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Filter matches with :func:`filter_matches`, using the RANSAC settings
        stored under ``self.conf["ransac"]``.

        Args:
            pred (Dict[str, Any]): dict of matches, including original keypoints.
                See :func:`filter_matches` for the expected keys.

        Returns:
            Dict[str, Any]: filtered matches
        """
        ransac_conf = self.conf["ransac"]
        return filter_matches(
            pred,
            ransac_method=ransac_conf["method"],
            ransac_reproj_threshold=ransac_conf["reproj_threshold"],
            ransac_confidence=ransac_conf["confidence"],
            ransac_max_iter=ransac_conf["max_iter"],
        )

    def visualize(
        self,
        log_path: Optional[Path] = None,
    ) -> None:
        """
        Visualize the matches.

        Args:
            log_path (Path, optional): The directory to save the images. Defaults to None.

        Returns:
            None

        Raises:
            RuntimeError: If no prediction is available yet.
        """
        if self.pred is None:
            # Without a prediction the lookups below would fail with an
            # opaque "'NoneType' object is not subscriptable".
            raise RuntimeError(
                "No prediction available: run the API on an image pair "
                "before calling visualize()."
            )
        if self.conf["dense"]:
            postfix = str(self.conf["matcher"]["model"]["name"])
        else:
            postfix = "{}_{}".format(
                str(self.conf["feature"]["model"]["name"]),
                str(self.conf["matcher"]["model"]["name"]),
            )
        titles = [
            "Image 0 - Keypoints",
            "Image 1 - Keypoints",
        ]
        pred: Dict[str, Any] = self.pred
        image0: np.ndarray = pred["image0_orig"]
        image1: np.ndarray = pred["image1_orig"]
        output_keypoints: np.ndarray = plot_images(
            [image0, image1], titles=titles, dpi=300
        )
        if "keypoints0_orig" in pred and "keypoints1_orig" in pred:
            plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
            text: str = (
                f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
                + f"# keypoints1: {len(pred['keypoints1_orig'])}"
            )
            add_text(0, text, fs=15)
        output_keypoints = fig2im(output_keypoints)
        # plot images with raw matches
        titles = [
            "Image 0 - Raw matched keypoints",
            "Image 1 - Raw matched keypoints",
        ]
        output_matches_raw, num_matches_raw = display_matches(
            pred, titles=titles, tag="KPTS_RAW"
        )
        # plot images with ransac matches
        titles = [
            "Image 0 - Ransac matched keypoints",
            "Image 1 - Ransac matched keypoints",
        ]
        output_matches_ransac, num_matches_ransac = display_matches(
            pred, titles=titles, tag="KPTS_RANSAC"
        )
        if log_path is not None:
            img_keypoints_path: Path = log_path / f"img_keypoints_{postfix}.png"
            img_matches_raw_path: Path = (
                log_path / f"img_matches_raw_{postfix}.png"
            )
            img_matches_ransac_path: Path = (
                log_path / f"img_matches_ransac_{postfix}.png"
            )
            cv2.imwrite(
                str(img_keypoints_path),
                output_keypoints[:, :, ::-1].copy(),  # RGB -> BGR
            )
            cv2.imwrite(
                str(img_matches_raw_path),
                output_matches_raw[:, :, ::-1].copy(),  # RGB -> BGR
            )
            cv2.imwrite(
                str(img_matches_ransac_path),
                output_matches_ransac[:, :, ::-1].copy(),  # RGB -> BGR
            )
        plt.close("all")


if __name__ == "__main__":
    config = load_config(ROOT / "ui/config.yaml")
    api = ImageMatchingAPI(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ui/app_class.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from pathlib import Path
2
  from typing import Any, Dict, Optional, Tuple
3
 
@@ -6,7 +7,8 @@ import numpy as np
6
  from easydict import EasyDict as edict
7
  from omegaconf import OmegaConf
8
 
9
- from hloc import flush_logs, read_logs
 
10
  from ui.sfm import SfmEngine
11
  from ui.utils import (
12
  GRADIO_VERSION,
@@ -272,24 +274,6 @@ class ImageMatchingApp:
272
  self.display_supported_algorithms()
273
 
274
  with gr.Column():
275
- with gr.Accordion("Open for More: Logs", open=False):
276
- logs = gr.Textbox(
277
- placeholder="\n" * 10,
278
- label="Logs",
279
- info="Verbose from inference will be displayed below.",
280
- lines=10,
281
- max_lines=10,
282
- autoscroll=True,
283
- elem_id="logs",
284
- show_copy_button=True,
285
- container=True,
286
- elem_classes="logs_class",
287
- )
288
- self.app.load(read_logs, None, logs, every=1)
289
- btn_clear_logs = gr.Button(
290
- "Clear logs", elem_id="logs-button"
291
- )
292
- btn_clear_logs.click(flush_logs, [], [])
293
 
294
  with gr.Accordion(
295
  "Open for More: Keypoints", open=True
@@ -523,7 +507,7 @@ class ImageMatchingApp:
523
  key: str = list(self.matcher_zoo.keys())[
524
  0
525
  ] # Get the first key from matcher_zoo
526
- flush_logs()
527
  return (
528
  None, # image0: Optional[np.ndarray]
529
  None, # image1: Optional[np.ndarray]
 
1
+ import sys
2
  from pathlib import Path
3
  from typing import Any, Dict, Optional, Tuple
4
 
 
7
  from easydict import EasyDict as edict
8
  from omegaconf import OmegaConf
9
 
10
+ sys.path.append(str(Path(__file__).parents[1]))
11
+
12
  from ui.sfm import SfmEngine
13
  from ui.utils import (
14
  GRADIO_VERSION,
 
274
  self.display_supported_algorithms()
275
 
276
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
  with gr.Accordion(
279
  "Open for More: Keypoints", open=True
 
507
  key: str = list(self.matcher_zoo.keys())[
508
  0
509
  ] # Get the first key from matcher_zoo
510
+ # flush_logs()
511
  return (
512
  None, # image0: Optional[np.ndarray]
513
  None, # image1: Optional[np.ndarray]
ui/config.yaml CHANGED
@@ -41,6 +41,7 @@ matcher_zoo:
41
  DUSt3R:
42
  # TODO: duster is under development
43
  enable: true
 
44
  matcher: duster
45
  dense: true
46
  info:
@@ -52,6 +53,7 @@ matcher_zoo:
52
  display: true
53
  GIM(dkm):
54
  enable: true
 
55
  matcher: gim(dkm)
56
  dense: true
57
  info:
@@ -63,6 +65,7 @@ matcher_zoo:
63
  display: true
64
  RoMa:
65
  matcher: roma
 
66
  dense: true
67
  info:
68
  name: RoMa #display name
@@ -73,6 +76,7 @@ matcher_zoo:
73
  display: true
74
  dkm:
75
  matcher: dkm
 
76
  dense: true
77
  info:
78
  name: DKM #display name
@@ -398,9 +402,9 @@ matcher_zoo:
398
  display: true
399
 
400
  sfd2+imp:
 
401
  matcher: imp
402
  feature: sfd2
403
- enable: true
404
  dense: false
405
  info:
406
  name: SFD2+IMP #display name
@@ -411,9 +415,9 @@ matcher_zoo:
411
  display: true
412
 
413
  sfd2+mnn:
 
414
  matcher: NN-mutual
415
  feature: sfd2
416
- enable: true
417
  dense: false
418
  info:
419
  name: SFD2+MNN #display name
 
41
  DUSt3R:
42
  # TODO: duster is under development
43
  enable: true
44
+ # skip_ci: true
45
  matcher: duster
46
  dense: true
47
  info:
 
53
  display: true
54
  GIM(dkm):
55
  enable: true
56
+ # skip_ci: true
57
  matcher: gim(dkm)
58
  dense: true
59
  info:
 
65
  display: true
66
  RoMa:
67
  matcher: roma
68
+ skip_ci: true
69
  dense: true
70
  info:
71
  name: RoMa #display name
 
76
  display: true
77
  dkm:
78
  matcher: dkm
79
+ skip_ci: true
80
  dense: true
81
  info:
82
  name: DKM #display name
 
402
  display: true
403
 
404
  sfd2+imp:
405
+ enable: true
406
  matcher: imp
407
  feature: sfd2
 
408
  dense: false
409
  info:
410
  name: SFD2+IMP #display name
 
415
  display: true
416
 
417
  sfd2+mnn:
418
+ enable: true
419
  matcher: NN-mutual
420
  feature: sfd2
 
421
  dense: false
422
  info:
423
  name: SFD2+MNN #display name
ui/sfm.py CHANGED
@@ -1,9 +1,10 @@
1
  import shutil
 
2
  import tempfile
3
  from pathlib import Path
4
  from typing import Any, Dict, List
5
 
6
- import pycolmap
7
 
8
  from hloc import (
9
  extract_features,
@@ -14,7 +15,12 @@ from hloc import (
14
  visualization,
15
  )
16
 
17
- from .viz import fig2im
 
 
 
 
 
18
 
19
 
20
  class SfmEngine:
 
1
  import shutil
2
+ import sys
3
  import tempfile
4
  from pathlib import Path
5
  from typing import Any, Dict, List
6
 
7
+ sys.path.append(str(Path(__file__).parents[1]))
8
 
9
  from hloc import (
10
  extract_features,
 
15
  visualization,
16
  )
17
 
18
+ try:
19
+ import pycolmap
20
+ except ImportError:
21
+ logger.warning("pycolmap not installed, some features may not work")
22
+
23
+ from ui.viz import fig2im
24
 
25
 
26
  class SfmEngine:
ui/utils.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import pickle
3
  import random
4
  import shutil
 
5
  import time
6
  import warnings
7
  from itertools import combinations
@@ -16,6 +17,8 @@ import poselib
16
  import psutil
17
  from PIL import Image
18
 
 
 
19
  from hloc import (
20
  DEVICE,
21
  extract_features,
@@ -26,8 +29,7 @@ from hloc import (
26
  matchers,
27
  )
28
  from hloc.utils.base_model import dynamic_load
29
-
30
- from .viz import display_keypoints, display_matches, fig2im, plot_images
31
 
32
  warnings.simplefilter("ignore")
33
 
 
2
  import pickle
3
  import random
4
  import shutil
5
+ import sys
6
  import time
7
  import warnings
8
  from itertools import combinations
 
17
  import psutil
18
  from PIL import Image
19
 
20
+ sys.path.append(str(Path(__file__).parents[1]))
21
+
22
  from hloc import (
23
  DEVICE,
24
  extract_features,
 
29
  matchers,
30
  )
31
  from hloc.utils.base_model import dynamic_load
32
+ from ui.viz import display_keypoints, display_matches, fig2im, plot_images
 
33
 
34
  warnings.simplefilter("ignore")
35
 
ui/viz.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import typing
2
  from pathlib import Path
3
  from typing import Dict, List, Optional, Tuple, Union
@@ -8,6 +9,8 @@ import matplotlib.pyplot as plt
8
  import numpy as np
9
  import seaborn as sns
10
 
 
 
11
  from hloc.utils.viz import add_text, plot_keypoints
12
 
13
  np.random.seed(1995)
 
1
+ import sys
2
  import typing
3
  from pathlib import Path
4
  from typing import Dict, List, Optional, Tuple, Union
 
9
  import numpy as np
10
  import seaborn as sns
11
 
12
+ sys.path.append(str(Path(__file__).parents[1]))
13
+
14
  from hloc.utils.viz import add_text, plot_keypoints
15
 
16
  np.random.seed(1995)