|
import os
|
|
import sys
|
|
|
|
# Append the directory two levels above this file's directory to the import
# search path. Presumably this lets the `utils` imports below resolve when the
# file is run directly -- TODO confirm which directory is the project root.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from utils import configs
|
|
from utils.functional import (
|
|
check_data_type_variable,
|
|
euclidean_distance_normalized,
|
|
get_device,
|
|
image_augmentations,
|
|
)
|
|
|
|
|
|
class BaseModelImageSimilarity(ABC):
    """Abstract base class for models that compare two images.

    The constructor stores and validates the configuration. Subclasses must
    implement :meth:`init_model`. ``self.model`` starts as ``None`` and has to
    be a callable (invoked as ``self.model(batch)``) before
    :meth:`get_similarity` is used.
    """

    # Key under which ``check_arguments`` validates the model whose public
    # name is ``configs.CLIP_NAME_MODEL``.
    _CLIP_REGISTRY_KEY = "clip"

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        """Store the configuration, pick the device and validate the arguments.

        Args:
            name_model: Model name; validated against ``configs.NAME_MODELS``
                (``configs.CLIP_NAME_MODEL`` is validated as ``"clip"``).
            freeze_model: Stored and type-checked here; not otherwise used by
                this base class.
            pretrained_model: Stored and type-checked here; not otherwise used
                by this base class.
            support_set_method: Must be one of ``configs.SUPPORT_SET_METHODS``.

        Raises:
            ValueError: If the model name or the support set method is not
                supported (see :meth:`check_arguments`).
        """
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.model = None
        self.device = get_device()

        self.check_arguments()

    def _registry_key(self) -> str:
        """Return the ``configs.NAME_MODELS`` key used to validate this model."""
        if self.name_model == configs.CLIP_NAME_MODEL:
            return self._CLIP_REGISTRY_KEY
        return self.name_model

    def check_arguments(self):
        """Validate the constructor arguments without modifying the instance.

        Raises:
            ValueError: If the model name (after mapping the CLIP name to
                ``"clip"``) is not a key of ``configs.NAME_MODELS``, or if
                ``support_set_method`` is not in
                ``configs.SUPPORT_SET_METHODS``.
        """
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        # Validate through a local key so ``self.name_model`` is never
        # overwritten, not even when one of the checks below raises.
        registry_key = self._registry_key()
        if registry_key not in configs.NAME_MODELS:
            raise ValueError(f"Model {registry_key} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )

    @abstractmethod
    def init_model(self):
        """Initialise the model; must be implemented by subclasses."""

    def get_similarity(self, image1: np.ndarray, image2: np.ndarray) -> dict:
        """Compare two images and classify them as the same image or not.

        Both images go through ``image_augmentations()``, get a leading batch
        axis, are moved to ``self.device`` and are passed through
        ``self.model`` with gradients disabled. The two outputs are compared
        with ``euclidean_distance_normalized``.

        Args:
            image1: First image.
            image2: Second image.

        Returns:
            A dict with the keys:
                ``"similarity"``: value returned by
                    ``euclidean_distance_normalized``.
                ``"result_similarity"``: ``"same image"`` when the similarity
                    is strictly greater than the model's
                    ``"image_similarity_threshold"``, else ``"not same image"``.
                ``"inference_time"``: seconds spent in the two model calls.

        Raises:
            RuntimeError: If ``self.model`` is still ``None``.
        """
        if self.model is None:
            raise RuntimeError(
                "self.model is None; the model must be initialised "
                "(see init_model) before calling get_similarity"
            )

        # Build the augmentation pipeline once and apply it to both images.
        augment = image_augmentations()
        batch1 = augment(image=image1)["image"].unsqueeze(0).to(self.device)
        batch2 = augment(image=image2)["image"].unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Only the two forward passes are timed.
            start_time = time.perf_counter()
            output1 = self.model(batch1)
            output2 = self.model(batch2)
            inference_time = time.perf_counter() - start_time

        features1 = output1.detach().cpu().numpy()
        features2 = output2.detach().cpu().numpy()
        similarity = euclidean_distance_normalized(features1, features2)

        # Prefer the entry stored under the model's own name. Fall back to the
        # key that check_arguments validated (the CLIP name is registered as
        # "clip"), so the lookup cannot fail on a name validation accepted.
        if self.name_model in configs.NAME_MODELS:
            config_key = self.name_model
        else:
            config_key = self._registry_key()
        threshold = configs.NAME_MODELS[config_key]["image_similarity_threshold"]

        if similarity > threshold:
            result_similarity = "same image"
        else:
            result_similarity = "not same image"

        return {
            "similarity": similarity,
            "result_similarity": result_similarity,
            "inference_time": inference_time,
        }
|
|
|