# machine-learning/superpoint.py
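"""SuperPoint keypoint detection demo.

Loads a set of images, detects keypoints with the pretrained
magic-leap-community/superpoint model, visualizes them, and matches
keypoints between image pairs using SIFT descriptors computed at the
detected locations.
"""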
from transformers import (
SuperPointImageProcessor,
SuperPointForKeypointDetection as SuperPointKeypointDetection,
)
import torch
import cv2
import numpy as np
import os
### DATA LOADING AND PREPROCESSING ###
# Create a directory for the output files
os.makedirs("output", exist_ok=True)
images = ["data/image0.jpg", "data/image1.jpg", "data/image2.jpg", "data/image3.jpg"]
original_images = []
resized_images = []
for image_path in images:
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read {image_path}")
    original_images.append(image)
    resized_image = cv2.resize(image, (640, 480))
    resized_images.append(resized_image)
### APPLYING THE SUPERPOINT FEATURE DETECTOR ###
processor = SuperPointImageProcessor.from_pretrained("magic-leap-community/superpoint")
model = SuperPointKeypointDetection.from_pretrained("magic-leap-community/superpoint")
inputs = processor(resized_images, return_tensors="pt")
# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)
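# A quick sanity check (sketch): outputs.mask flags the valid keypoints for
# each image in the padded batch, so its sum is the per-image keypoint count.
for idx in range(len(resized_images)):
    print(f"image{idx}: {int(outputs.mask[idx].sum())} keypoints detected")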
### VISUALIZING THE RESULTS ###
def draw_keypoints(image, keypoints, color=(0, 255, 0), radius=2):
    for kp in keypoints:
        x, y = int(kp[0]), int(kp[1])
        cv2.circle(image, (x, y), radius, color, -1)
    return image

def create_blank_image(shape):
    return np.zeros((shape[0], shape[1], 3), dtype=np.uint8)
all_keypoints = []
for i, (original_image, resized_image) in enumerate(
    zip(original_images, resized_images)
):
    # Keep only the valid keypoints flagged by the batch mask;
    # squeeze(1) (rather than squeeze()) stays 1-D even for a single keypoint
    image_mask = outputs.mask[i]
    image_indices = torch.nonzero(image_mask).squeeze(1)
    image_keypoints = outputs.keypoints[i][image_indices]
    # Scale the keypoints back to the original image size
    scale_x = original_image.shape[1] / resized_image.shape[1]
    scale_y = original_image.shape[0] / resized_image.shape[0]
    scaled_keypoints = image_keypoints.clone()
    scaled_keypoints[:, 0] *= scale_x
    scaled_keypoints[:, 1] *= scale_y
    all_keypoints.append(scaled_keypoints)
    # Save the original image with the keypoints drawn on top
    keypoints_image = draw_keypoints(original_image.copy(), scaled_keypoints)
    cv2.imwrite(f"output/image{i}.png", keypoints_image)
    # Save an image showing only the keypoints
    blank_image = create_blank_image(original_image.shape[:2])
    just_keypoints_image = draw_keypoints(blank_image, scaled_keypoints)
    cv2.imwrite(f"output/image{i}_just_keypoints.png", just_keypoints_image)
### FEATURE MATCHING ###
def match_keypoints(img1, kp1, img2, kp2, method="flann"):
    # Convert keypoints to cv2.KeyPoint objects
    kp1 = [cv2.KeyPoint(x=float(kp[0]), y=float(kp[1]), size=1) for kp in kp1]
    kp2 = [cv2.KeyPoint(x=float(kp[0]), y=float(kp[1]), size=1) for kp in kp2]
    # Compute SIFT descriptors at the SuperPoint keypoint locations
    sift = cv2.SIFT_create()
    _, des1 = sift.compute(img1, kp1)
    _, des2 = sift.compute(img2, kp2)
    if method == "flann":
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(des1, des2, k=2)
    else:  # BF matcher
        bf = cv2.BFMatcher()
        matches = bf.knnMatch(des1, des2, k=2)
    # Apply Lowe's ratio test; knnMatch can return fewer than two
    # neighbours per query, so guard against short pairs
    good_matches = []
    for pair in matches:
        if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance:
            good_matches.append(pair[0])
    # Draw the surviving matches side by side
    img_matches = cv2.drawMatches(
        img1,
        kp1,
        img2,
        kp2,
        good_matches,
        None,
        flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
    )
    return img_matches
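# SuperPoint also produces its own descriptors (outputs.descriptors), so the
# SIFT step above is a design choice rather than a requirement. A minimal
# sketch of matching those native descriptors directly, assuming the same
# mask-based index selection used earlier; match_superpoint_descriptors is a
# hypothetical helper, not part of the original script:
def match_superpoint_descriptors(des1, des2, ratio=0.7):
    # cv2 matchers expect float32 NumPy arrays
    des1 = des1.detach().numpy().astype(np.float32)
    des2 = des2.detach().numpy().astype(np.float32)
    bf = cv2.BFMatcher(cv2.NORM_L2)
    matches = bf.knnMatch(des1, des2, k=2)
    return [p[0] for p in matches if len(p) == 2 and p[0].distance < ratio * p[1].distance]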
# Apply the feature-matching algorithms if there is more than one image
if len(images) > 1:
    for i in range(1, len(images)):
        # FLANN matching
        flann_matches = match_keypoints(
            original_images[0],
            all_keypoints[0],
            original_images[i],
            all_keypoints[i],
            method="flann",
        )
        cv2.imwrite(f"output/image0_image{i}_flann.png", flann_matches)
        # BF matching
        bf_matches = match_keypoints(
            original_images[0],
            all_keypoints[0],
            original_images[i],
            all_keypoints[i],
            method="bf",
        )
        cv2.imwrite(f"output/image0_image{i}_bf.png", bf_matches)