from transformers import (
    SuperPointImageProcessor,
    SuperPointForKeypointDetection,
)
import torch
import cv2
import numpy as np
import os


### DATA LOADING AND PREPROCESSING ###

# Create a directory for the output files
os.makedirs("output", exist_ok=True)

images = ["data/image0.jpg", "data/image1.jpg", "data/image2.jpg", "data/image3.jpg"]
original_images = []
resized_images = []

for image_path in images:
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read {image_path}")
    original_images.append(image)
    # Resize to 640x480, the resolution SuperPoint is typically run at
    resized_image = cv2.resize(image, (640, 480))
    resized_images.append(resized_image)


### APPLYING THE SUPERPOINT KEYPOINT DETECTOR ###


processor = SuperPointImageProcessor.from_pretrained("magic-leap-community/superpoint")
model = SuperPointForKeypointDetection.from_pretrained(
    "magic-leap-community/superpoint"
)

# OpenCV loads images as BGR, while the Hugging Face processor expects RGB
rgb_images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in resized_images]
inputs = processor(rgb_images, return_tensors="pt")

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)
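
# Alternative post-processing (a sketch, assuming a recent transformers release):
# newer versions of SuperPointImageProcessor expose post_process_keypoint_detection,
# which returns one dict per image with "keypoints", "scores" and "descriptors",
# the keypoints already in pixel coordinates of the target sizes passed in, so the
# manual rescaling done below would not be needed.
if hasattr(processor, "post_process_keypoint_detection"):
    target_sizes = [(img.shape[0], img.shape[1]) for img in original_images]
    post_processed = processor.post_process_keypoint_detection(outputs, target_sizes)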


### VISUALIZING THE RESULTS ###


def draw_keypoints(image, keypoints, color=(0, 255, 0), radius=2):
    for kp in keypoints:
        x, y = int(kp[0]), int(kp[1])
        cv2.circle(image, (x, y), radius, color, -1)
    return image


def create_blank_image(shape):
    return np.zeros((shape[0], shape[1], 3), dtype=np.uint8)


all_keypoints = []

for i, (original_image, resized_image) in enumerate(
    zip(original_images, resized_images)
):
    # outputs.mask marks which entries of the padded keypoint tensor are real detections
    image_mask = outputs.mask[i]
    image_indices = torch.nonzero(image_mask).squeeze(1)
    image_keypoints = outputs.keypoints[i][image_indices]

    # Scale the keypoints back to the original image size. This assumes
    # outputs.keypoints holds pixel coordinates in the 640x480 processed image;
    # recent transformers releases return coordinates normalized to [0, 1], in which
    # case multiply by the original width/height (or use the post-processing sketch above).
    scale_x = original_image.shape[1] / resized_image.shape[1]
    scale_y = original_image.shape[0] / resized_image.shape[0]
    scaled_keypoints = image_keypoints.clone()
    scaled_keypoints[:, 0] *= scale_x
    scaled_keypoints[:, 1] *= scale_y

    all_keypoints.append(scaled_keypoints)

    # Draw the keypoints on a copy of the original image
    keypoints_image = draw_keypoints(original_image.copy(), scaled_keypoints)
    cv2.imwrite(f"output/image{i}.png", keypoints_image)

    # Draw the keypoints alone on a blank image
    blank_image = create_blank_image(original_image.shape[:2])
    just_keypoints_image = draw_keypoints(blank_image, scaled_keypoints)
    cv2.imwrite(f"output/image{i}_just_keypoints.png", just_keypoints_image)


### FEATURE MATCHING ###


def match_keypoints(img1, kp1, img2, kp2, method="flann"):
    # Convert keypoints to cv2.KeyPoint objects
    kp1 = [cv2.KeyPoint(x=float(kp[0]), y=float(kp[1]), size=1) for kp in kp1]
    kp2 = [cv2.KeyPoint(x=float(kp[0]), y=float(kp[1]), size=1) for kp in kp2]

    # Compute SIFT descriptors at the SuperPoint keypoint locations; use the keypoints
    # returned by compute() so their order stays aligned with the descriptors
    sift = cv2.SIFT_create()
    kp1, des1 = sift.compute(img1, kp1)
    kp2, des2 = sift.compute(img2, kp2)

    if method == "flann":
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(des1, des2, k=2)
    else:  # BF Matcher
        bf = cv2.BFMatcher()
        matches = bf.knnMatch(des1, des2, k=2)

    # Apply Lowe's ratio test (knnMatch may return fewer than two neighbours)
    good_matches = []
    for pair in matches:
        if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance:
            good_matches.append(pair[0])

    # Draw matches
    img_matches = cv2.drawMatches(
        img1,
        kp1,
        img2,
        kp2,
        good_matches,
        None,
        flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
    )
    return img_matches
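

# SuperPoint also returns its own 256-dimensional descriptors (outputs.descriptors),
# so SIFT does not have to be recomputed at the keypoint locations. A minimal sketch
# of a ratio-test matcher over those descriptors; per image they can be gathered with
# outputs.descriptors[i][image_indices], mirroring the keypoint selection above.
def match_superpoint_descriptors(descriptors1, descriptors2, ratio=0.7):
    # OpenCV matchers expect float32 NumPy arrays
    des1 = descriptors1.detach().cpu().numpy().astype(np.float32)
    des2 = descriptors2.detach().cpu().numpy().astype(np.float32)
    bf = cv2.BFMatcher(cv2.NORM_L2)
    matches = bf.knnMatch(des1, des2, k=2)
    good = []
    for pair in matches:
        if len(pair) == 2 and pair[0].distance < ratio * pair[1].distance:
            good.append(pair[0])
    return good  # list of cv2.DMatch; pair with cv2.KeyPoint lists to draw matches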


# Apply the feature matching algorithms if there is more than one image
if len(images) > 1:
    for i in range(1, len(images)):
        # FLANN matching
        flann_matches = match_keypoints(
            original_images[0],
            all_keypoints[0],
            original_images[i],
            all_keypoints[i],
            method="flann",
        )
        cv2.imwrite(f"output/image0_image{i}_flann.png", flann_matches)

        # BF matching
        bf_matches = match_keypoints(
            original_images[0],
            all_keypoints[0],
            original_images[i],
            all_keypoints[i],
            method="bf",
        )
        cv2.imwrite(f"output/image0_image{i}_bf.png", bf_matches)