|
import argparse |
|
import base64 |
|
import os |
|
import pickle |
|
import time |
|
from typing import Dict, List |
|
|
|
import cv2 |
|
import numpy as np |
|
import requests |
|
|
|
# Base URL of the image matching API. Defaults to a local server; set the
# REMOTE_URL_RAILWAY environment variable to target a remote deployment.
# An empty value is treated as unset, and a trailing "/" is stripped so the
# joined paths below never contain "//" (which servers commonly fail to route).
ENDPOINT = (
    os.environ.get("REMOTE_URL_RAILWAY") or "http://127.0.0.1:8001"
).rstrip("/")

print(f"API ENDPOINT: {ENDPOINT}")

# Endpoint URLs used by the request helpers in this file.
API_VERSION = f"{ENDPOINT}/version"
API_URL_MATCH = f"{ENDPOINT}/v1/match"
API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
|
|
|
|
|
def read_image(path: str) -> str:
    """
    Read an image from disk as grayscale, encode it as PNG and return the
    PNG bytes as a base64 string.

    Args:
        path (str): The path to the image to read.

    Returns:
        str: The base64-encoded PNG image, as UTF-8 text.

    Raises:
        OSError: If OpenCV cannot read an image from ``path`` (missing file,
            unreadable or unsupported format).
        ValueError: If the decoded image cannot be encoded as PNG.
    """
    # cv2.imread does not raise on failure; it returns None, which would only
    # surface later as an opaque error inside cv2.imencode.
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise OSError(f"Could not read image: {path}")

    ok, buffer = cv2.imencode(".png", img)
    if not ok:
        raise ValueError(f"Could not encode image as PNG: {path}")

    return base64.b64encode(buffer).decode("utf-8")
|
|
|
|
|
def do_api_requests(url=API_URL_EXTRACT, **kwargs):
    """
    Send a JSON POST request to the image matching service.

    A default request body is built on every call and then overridden
    key-by-key with ``kwargs``.

    Args:
        url (str): The URL of the API endpoint to use. Defaults to the
            feature extraction endpoint.
        **kwargs: Request-body fields that replace the defaults below, e.g.
            ``data=[<base64 image>, ...]``. Values must be JSON-serializable.

    Returns:
        The decoded JSON response (plain Python lists/dicts, not ndarrays)
        when the server answers with HTTP 200. For the extract endpoint,
        callers in this file treat it as a list with one dict per image.
        Returns ``None`` if the request fails, the status is not 200, or the
        body is not valid JSON; the reason is printed rather than raised.
    """
    # Built fresh per call so that updates from kwargs never leak between calls.
    reqbody = {
        "data": [],
        "max_keypoints": [100, 100],
        "timestamps": ["0", "1"],
        "grayscale": 0,
        "image_hw": [[640, 480], [320, 240]],
        "feature_type": 0,
        "rotates": [0.0, 0.0],
        "scales": [1.0, 1.0],
        "reference_points": [[640, 480], [320, 240]],
        "binarize": True,
    }
    reqbody.update(kwargs)

    try:
        r = requests.post(url, json=reqbody)
    except requests.RequestException as e:
        # Network-level failure (connection refused, DNS, etc.).
        print(f"An error occurred: {e}")
        return None

    if r.status_code != 200:
        print(f"Error: Response code {r.status_code} - {r.text}")
        return None

    try:
        return r.json()
    except ValueError as e:
        # HTTP 200 but the body is not valid JSON.
        print(f"An error occurred: {e}")
        return None
|
|
|
|
|
def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
    """
    Upload two image files to the match endpoint and return its response.

    Args:
        path0 (str): The path to the first image.
        path1 (str): The path to the second image.

    Returns:
        Dict[str, np.ndarray]: The server's JSON object with every value
        converted to an ndarray. The exact keys and shapes are defined by the
        server (presumably keypoints and matches for both images; confirm
        against the API). An empty dict is returned, after printing the
        error, if the server does not answer with HTTP 200.

    Raises:
        OSError: If either image file cannot be opened.
        requests.RequestException: If the HTTP request itself fails; it is
            not caught here.
    """
    # A single ``with`` closes the first file even if opening the second fails.
    with open(path0, "rb") as image0, open(path1, "rb") as image1:
        response = requests.post(
            API_URL_MATCH, files={"image0": image0, "image1": image1}
        )

    if response.status_code != 200:
        print(
            f"Error: Response code {response.status_code} - {response.text}"
        )
        return {}

    return {key: np.array(value) for key, value in response.json().items()}
|
|
|
|
|
def send_request_extract(
    input_images: str, viz: bool = False
) -> List[Dict[str, np.ndarray]]:
    """
    Send a request to the API to extract features from one image.

    Args:
        input_images (str): The path to a single image (despite the plural
            name).
        viz (bool): If True, plot the returned keypoints over the returned
            image and write the figure to ``demo_match.jpg``.

    Returns:
        The decoded JSON response of the extract endpoint: a list with one
        dict per submitted image, whose values are plain JSON lists (not
        ndarrays). This function reads the "keypoints" key, and for ``viz``
        also "keypoints_orig" and "image_orig"; other keys depend on the
        server. If the request failed or returned nothing, the falsy response
        (``None`` or ``[]``) is returned unchanged.
    """
    image_data = read_image(input_images)
    response = do_api_requests(url=API_URL_EXTRACT, data=[image_data])

    # do_api_requests returns None on failure (it already printed why);
    # indexing it would raise an unhelpful TypeError.
    if not response:
        print("No features returned for: {}".format(input_images))
        return response

    features = response[0]
    print("Keypoints detected: {}".format(len(features["keypoints"])))

    if viz:
        _save_keypoints_visualization(features, input_images)

    return response


def _save_keypoints_visualization(
    features: dict, image_path: str, out_path: str = "demo_match.jpg"
) -> None:
    """Plot one image's returned keypoints over its returned image and save it."""
    if "image_orig" not in features:
        print("Response has no 'image_orig'; skipping visualization.")
        return

    # Imported lazily: the plotting stack is only needed when viz is requested.
    from hloc.utils.viz import plot_keypoints
    from ui.viz import fig2im, plot_images

    kpts = np.array(features["keypoints_orig"])
    # Use the image the server sent back (not an array of the key name).
    img_orig = np.array(features["image_orig"])

    # Titles are presumably indexed per image (titles[i]), so pass a list
    # rather than a bare string.
    fig = plot_images(
        [img_orig], titles=[os.path.basename(image_path)], dpi=300
    )
    plot_keypoints([kpts])
    canvas = fig2im(fig)
    # Assumes fig2im yields RGB; reverse the channel axis for cv2's BGR order.
    cv2.imwrite(out_path, canvas[:, :, ::-1].copy())
|
|
|
|
|
def get_api_version(timeout: float = 30.0):
    """
    Query the service's version endpoint and print the reported version.

    Args:
        timeout (float): Seconds to wait for the server before giving up.
            Without a timeout, ``requests`` can block indefinitely on an
            unresponsive server.

    Returns:
        The "version" field of the JSON response, or ``None`` if the check
        failed. Errors are printed, not raised, so a failed check does not
        stop the caller.
    """
    try:
        response = requests.get(API_VERSION, timeout=timeout)
        # Report HTTP errors directly instead of a confusing KeyError from
        # indexing an error body.
        response.raise_for_status()
        version = response.json()["version"]
    except Exception as e:
        # Best-effort diagnostic at the script boundary: report and continue.
        print(f"An error occurred: {e}")
        return None

    print("API VERSION: {}".format(version))
    return version
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Send an image to the image matching API and time repeated "
            "feature-extraction requests."
        )
    )
    parser.add_argument(
        "--image0",
        required=False,
        help="Path to the image used for feature extraction",
        default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg",
    )
    parser.add_argument(
        "--image1",
        required=False,
        help="Path to a second image (currently unused by this script)",
        default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
    )
    parser.add_argument(
        "--num_iters",
        type=int,
        default=10,
        help="Number of extraction requests to send and time (default: 10)",
    )
    args = parser.parse_args()

    get_api_version()

    # Initialized so the pickle step below works even with --num_iters 0.
    preds = None
    for _ in range(args.num_iters):
        # perf_counter is monotonic and intended for measuring durations.
        t1 = time.perf_counter()
        preds = send_request_extract(args.image0)
        t2 = time.perf_counter()
        print(f"Time cost2: {(t2 - t1)} seconds")

    # Only the last iteration's response is persisted.
    with open("preds.pkl", "wb") as f:
        pickle.dump(preds, f)
|
|