import albumentations as A
from transformers import PreTrainedModel
from PIL import Image
import numpy as np
import torch
import cv2

from configuration_cetacean_classifier import CetaceanClassifierConfig
from train import SphereClassifier

# Species names; `forward` indexes this array with positions taken from the
# species logits, so entry i is the label reported for logit i.
WHALE_CLASSES = np.array(
    [
        "beluga",
        "blue_whale",
        "bottlenose_dolphin",
        "brydes_whale",
        "commersons_dolphin",
        "common_dolphin",
        "cuviers_beaked_whale",
        "dusky_dolphin",
        "false_killer_whale",
        "fin_whale",
        "frasiers_dolphin",
        "gray_whale",
        "humpback_whale",
        "killer_whale",
        "long_finned_pilot_whale",
        "melon_headed_whale",
        "minke_whale",
        "pantropic_spotted_dolphin",
        "pygmy_killer_whale",
        "rough_toothed_dolphin",
        "sei_whale",
        "short_finned_pilot_whale",
        "southern_right_whale",
        "spinner_dolphin",
        "spotted_dolphin",
        "white_sided_dolphin",
    ]
)


class CetaceanClassifierModelForImageClassification(PreTrainedModel):
    """Wrap ``SphereClassifier`` as a ``PreTrainedModel`` for species prediction.

    ``forward`` takes a single image, preprocesses it and returns the names of
    the three highest-scoring entries of ``WHALE_CLASSES``.
    """

    config_class = CetaceanClassifierConfig

    def __init__(self, config):
        """Build the wrapped classifier from ``config`` and put it in eval mode.

        The network is constructed from ``config.to_dict()`` only; no
        checkpoint is read here.
        """
        super().__init__(config)
        self.model = SphereClassifier(cfg=config.to_dict())
        self.model.eval()
        self.config = config
        # NOTE(review): the random crop is enabled for inference, so repeated
        # calls on the same image can give different inputs to the network.
        # Confirm this is intended.
        self.transforms = self.make_transforms(data_aug=True)

    def make_transforms(self, data_aug: bool):
        """Return the albumentations pipeline applied in ``preprocess_image``.

        Args:
            data_aug: When true, the pipeline holds one ``RandomResizedCrop``
                parameterised by ``config.image_size`` and ``config.aug``
                (keys ``crop_scale``, ``crop_l``, ``crop_r``). When false the
                pipeline contains no transforms.

        Returns:
            An ``A.Compose`` instance.
        """
        augments = []
        if data_aug:
            aug = self.config.aug
            augments = [
                A.RandomResizedCrop(
                    self.config.image_size[0],
                    self.config.image_size[1],
                    scale=(aug["crop_scale"], 1.0),
                    ratio=(aug["crop_l"], aug["crop_r"]),
                ),
            ]
        return A.Compose(augments)

    def preprocess_image(self, img) -> torch.Tensor:
        """Convert one image into a 4-D tensor with a leading batch axis of 1.

        Args:
            img: Image array accepted by ``cv2.cvtColor`` with
                ``COLOR_BGR2RGB`` (presumably H x W x 3 in BGR order, as
                returned by ``cv2.imread`` -- TODO confirm against callers).

        Returns:
            Tensor built from the resized and transformed image.
        """
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # The configured size may arrive as a list after a JSON round-trip;
        # hand OpenCV a tuple for the target size.
        # NOTE(review): cv2.resize reads this as (width, height) while
        # make_transforms passes image_size[0], image_size[1] as height, width.
        # The two agree only for square sizes.
        image = cv2.resize(
            rgb, tuple(self.config.image_size), interpolation=cv2.INTER_CUBIC
        )
        image = self.transforms(image=image)["image"]
        # transpose(2, 0) swaps the first and last axes, turning (H, W, C)
        # into (C, W, H).
        # NOTE(review): this differs from the usual (C, H, W) layout; left
        # unchanged because the layout SphereClassifier expects is not visible
        # in this file.
        return torch.Tensor(image).transpose(2, 0).unsqueeze(0)

    def forward(self, img, labels=None):
        """Predict the three most likely species for ``img``.

        Args:
            img: Image in the format expected by ``preprocess_image``.
            labels: Unused; kept so existing callers that pass it still work.

        Returns:
            ``{"predictions": names}`` where ``names`` is a NumPy array with
            the three ``WHALE_CLASSES`` entries of highest species logit,
            best first.
        """
        # Run on whatever device the weights live on.
        batch = self.preprocess_image(img).to(self.device)
        # Only class names are returned, so no autograd graph is needed.
        with torch.no_grad():
            _, species_logits = self.model(batch)
        # .cpu() is required before .numpy() when the model is on an
        # accelerator.
        species_logits = species_logits.detach().cpu().numpy()
        # Ascending argsort of the single batch row, reversed to best-first.
        ranked = species_logits.argsort()[0][::-1]
        top_three = ranked[:3]
        return {"predictions": WHALE_CLASSES[top_three]}