hafidhsoekma's picture
First commit
49bceed
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import time
from abc import ABC, abstractmethod
import numpy as np
import torch
from utils import configs
from utils.functional import check_data_type_variable, get_device, image_augmentations
class BaseModelImageEmbeddings(ABC):
    """Abstract base for wrappers that turn a single image into an embedding.

    Subclasses implement :meth:`init_model`, which is expected to assign a
    callable model to ``self.model``; :meth:`get_embeddings` preprocesses one
    image, runs ``self.model`` on it and returns the output plus the time the
    forward pass took.

    Args:
        name_model: Model identifier. Must be a key of ``configs.NAME_MODELS``,
            or equal to ``configs.CLIP_NAME_MODEL`` (validated under the
            ``"clip"`` key).
        freeze_model: Stored on the instance; its meaning is up to subclasses.
        pretrained_model: Stored on the instance; its meaning is up to
            subclasses.
        support_set_method: Must be one of ``configs.SUPPORT_SET_METHODS``.

    Raises:
        ValueError: If ``name_model`` or ``support_set_method`` is not
            supported. Type mismatches are reported by
            ``check_data_type_variable`` (exception type defined there).
    """

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        # Filled in by the subclass's init_model(); get_embeddings() refuses
        # to run while it is still None.
        self.model = None
        self.device = get_device()
        self.check_arguments()

    def check_arguments(self):
        """Validate the constructor arguments without modifying them.

        Raises:
            ValueError: If the model name or support-set method is not
                supported.
        """
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        # The CLIP model name is registered in NAME_MODELS under the generic
        # "clip" key. Map it in a local variable rather than temporarily
        # overwriting self.name_model, so the attribute can never be left
        # altered when a later check raises.
        lookup_name = (
            "clip" if self.name_model == configs.CLIP_NAME_MODEL else self.name_model
        )
        if lookup_name not in configs.NAME_MODELS:
            raise ValueError(f"Model {lookup_name} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )

    @abstractmethod
    def init_model(self):
        """Build the underlying model and assign it to ``self.model``."""

    def get_embeddings(self, image: np.ndarray) -> dict:
        """Compute the model output for a single image.

        Args:
            image: Raw image passed to ``image_augmentations()`` as its
                ``image`` argument (presumably an HxWxC array — TODO confirm).

        Returns:
            A dict with:
                ``"embeddings"``: the model output for a batch of one image,
                    moved to the CPU as a NumPy array.
                ``"inference_time"``: seconds spent in the forward pass,
                    measured with ``time.perf_counter``.

        Raises:
            RuntimeError: If ``self.model`` is still ``None``, i.e.
                ``init_model`` has not been called.
        """
        if self.model is None:
            raise RuntimeError(
                "self.model is not initialised; call init_model() "
                "before get_embeddings()"
            )

        image_input = image_augmentations()(image=image)["image"]
        # Add a batch dimension of size 1 and move to the target device.
        image_input = image_input.unsqueeze(dim=0).to(self.device)

        # CUDA kernels run asynchronously: without synchronizing, the timer
        # would only measure the kernel launches, not the actual computation.
        on_cuda = image_input.is_cuda
        with torch.no_grad():
            if on_cuda:
                torch.cuda.synchronize(image_input.device)
            start_time = time.perf_counter()
            embeddings = self.model(image_input)
            if on_cuda:
                torch.cuda.synchronize(image_input.device)
            inference_time = time.perf_counter() - start_time

        embeddings = embeddings.detach().cpu().numpy()
        return {
            "embeddings": embeddings,
            "inference_time": inference_time,
        }