import os
import sys

# Make the project root importable when this file is run as a script.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import numpy as np
from PIL import Image

from models.base_model import BaseModelImageEmbeddings
from utils import configs

from .backbone_model import CLIPModel, TorchModel


class ImageEmbeddings(BaseModelImageEmbeddings):
    """Image-embedding extractor backed by either a CLIP or a torch backbone."""

    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        """Store the configuration via the base class, then build the backbone.

        Args:
            name_model: backbone identifier; ``"clip"`` selects ``CLIPModel``,
                any other value is forwarded to ``TorchModel``.
            freeze_model: whether the backbone weights are frozen.
            pretrained_model: whether pretrained weights are loaded.
            support_set_method: support-set strategy name (e.g. ``"5_shot"``).
        """
        super().__init__(name_model, freeze_model, pretrained_model, support_set_method)
        self.init_model()

    def init_model(self) -> None:
        """Instantiate the backbone, move it to ``self.device``, set eval mode."""
        # "clip" gets the dedicated CLIP wrapper; every other name is treated
        # as a torch backbone identifier.
        if self.name_model == "clip":
            backbone = CLIPModel(
                configs.CLIP_NAME_MODEL, self.freeze_model, self.pretrained_model
            )
        else:
            backbone = TorchModel(
                self.name_model, self.freeze_model, self.pretrained_model
            )
        self.model = backbone
        # NOTE(review): self.device is presumably set by BaseModelImageEmbeddings
        # — confirm against the base class.
        self.model.to(self.device)
        self.model.eval()


if __name__ == "__main__":
    embedder = ImageEmbeddings("mobilenetv3_large_100", True, True, "5_shot")
    sample = Image.open(
        "../../assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png"
    ).convert("RGB")
    print(embedder.get_embeddings(np.array(sample)))