import gradio as gr
from IJEPA_finetune import ViTIJEPA
import torch
from einops import rearrange
from torchvision.transforms import Compose
import torchvision


class CMYKToRGB(object):
    """Convert a 4-channel (CMYK) image tensor to 3-channel RGB.

    Tensors whose channel dimension (``-3``) is not 4, or that have fewer than
    three dimensions, are returned unchanged, so RGB inputs pass straight
    through. The channel count is the only signal used, so a 4-channel RGBA
    tensor would also be treated as CMYK.

    Args:
        scale: Multiplier applied to the RGB output. Input channel values are
            expected in ``[0, 1]`` (as produced by ``ToTensor`` earlier in the
            pipeline). The default ``1.0`` keeps converted images on the same
            ``[0, 1]`` scale as RGB images that skip this transform. Pass
            ``255`` to reproduce this transform's previous output scale (e.g.
            if the checkpoint was trained with it).
    """

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, img):
        # This transform runs after ToTensor, so it must receive a tensor.
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input image should be a torch.Tensor")
        # Only 4-channel images are converted; everything else passes through.
        if img.ndim < 3 or img.shape[-3] != 4:
            return img

        c, m, y, k = img.unbind(-3)
        # Standard CMYK -> RGB formula for channel values in [0, 1].
        r = self.scale * (1 - c) * (1 - k)
        g = self.scale * (1 - m) * (1 - k)
        b = self.scale * (1 - y) * (1 - k)
        return torch.stack([r, g, b], dim=-3)


# Genus labels, in the order of the model's output logits.
classes = [
    'Acanthostichus', 'Aenictus', 'Amblyopone', 'Attini', 'Bothriomyrmecini',
    'Camponotini', 'Cerapachys', 'Cheliomyrmex', 'Crematogastrini',
    'Cylindromyrmex', 'Dolichoderini', 'Dorylus', 'Eciton', 'Ectatommini',
    'Formicini', 'Fulakora', 'Gesomyrmecini', 'Gigantiopini',
    'Heteroponerini', 'Labidus', 'Lasiini', 'Leptomyrmecini', 'Lioponera',
    'Melophorini', 'Myopopone', 'Myrmecia', 'Myrmelachistini', 'Myrmicini',
    'Myrmoteratini', 'Mystrium', 'Neivamyrmex', 'Nomamyrmex', 'Oecophyllini',
    'Ooceraea', 'Paraponera', 'Parasyscia', 'Plagiolepidini', 'Platythyreini',
    'Pogonomyrmecini', 'Ponerini', 'Prionopelta', 'Probolomyrmecini',
    'Proceratiini', 'Pseudomyrmex', 'Solenopsidini', 'Stenammini',
    'Stigmatomma', 'Syscia', 'Tapinomini', 'Tetraponera', 'Zasphinctus',
]

# Maps output index -> genus name.
idx_to_class = dict(enumerate(classes))
# Backward-compatible alias; despite its name it maps index -> genus name.
class_to_idx = idx_to_class

# Preprocessing: tensor in [0, 1], resized to 64x64, CMYK converted to RGB.
train_transforms = Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((64, 64), antialias=True),
        CMYKToRGB(),
    ]
)

# NOTE(review): positional args presumably are (img_size, patch_size,
# in_channels, embed_dim, depth, heads, num_classes) -- TODO confirm against
# IJEPA_finetune.ViTIJEPA.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (torch>=1.13 also supports weights_only=True).
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
# Switch to inference mode so train-only layers (e.g. dropout) are disabled;
# without this, predictions are noisy and non-deterministic.
model.eval()


def ant_genus_classification(image):
    """Classify an ant image into a genus.

    Args:
        image: Image from the Gradio ``"image"`` input (presumably an HxWxC
            uint8 NumPy array -- TODO confirm for the installed Gradio
            version); anything ``ToTensor`` accepts works.

    Returns:
        dict mapping every genus name in ``classes`` to its softmax
        probability, the format ``gr.Label`` consumes. An empty dict when no
        image is supplied.
    """
    # Guard against a missing upload instead of crashing inside ToTensor.
    if image is None:
        return {}

    batch = train_transforms(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    return {genus: float(p) for genus, p in zip(classes, probs)}


demo = gr.Interface(fn=ant_genus_classification, inputs="image", outputs=gr.Label(num_top_classes=3))

if __name__ == "__main__":
    demo.launch(debug=True)