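"""Gradio demo that serves a fine-tuned I-JEPA Vision Transformer (ViTIJEPA)
to classify ant photographs into 51 genus/tribe labels."""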
import gradio as gr
import torch
import torchvision

from IJEPA_finetune import ViTIJEPA
class CMYKToRGB(object):
    """Convert a 4-channel CMYK tensor to 3-channel RGB; pass other tensors through."""

    def __call__(self, img):
        # Ensure the input is a torch.Tensor (this transform runs after ToTensor)
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input image should be a torch.Tensor")
        # Pass through images that do not have 4 (CMYK) channels
        if img.shape[-3] != 4:
            return img
        # Extract the CMYK channels
        c, m, y, k = img.unbind(-3)
        # CMYK -> RGB conversion; ToTensor has already scaled channels to [0, 1],
        # so the conventional 255 factor is dropped to keep the output on the
        # same scale as RGB inputs
        r = (1 - c) * (1 - k)
        g = (1 - m) * (1 - k)
        b = (1 - y) * (1 - k)
        # Stack the RGB channels back into a single tensor
        rgb_img = torch.stack([r, g, b], dim=-3)
        return rgb_img
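# Ant genus/tribe labels; their order is assumed to match the output index
# order of the fine-tuned checkpoint loaded below.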
classes = ['Acanthostichus',
'Aenictus',
'Amblyopone',
'Attini',
'Bothriomyrmecini',
'Camponotini',
'Cerapachys',
'Cheliomyrmex',
'Crematogastrini',
'Cylindromyrmex',
'Dolichoderini',
'Dorylus',
'Eciton',
'Ectatommini',
'Formicini',
'Fulakora',
'Gesomyrmecini',
'Gigantiopini',
'Heteroponerini',
'Labidus',
'Lasiini',
'Leptomyrmecini',
'Lioponera',
'Melophorini',
'Myopopone',
'Myrmecia',
'Myrmelachistini',
'Myrmicini',
'Myrmoteratini',
'Mystrium',
'Neivamyrmex',
'Nomamyrmex',
'Oecophyllini',
'Ooceraea',
'Paraponera',
'Parasyscia',
'Plagiolepidini',
'Platythyreini',
'Pogonomyrmecini',
'Ponerini',
'Prionopelta',
'Probolomyrmecini',
'Proceratiini',
'Pseudomyrmex',
'Solenopsidini',
'Stenammini',
'Stigmatomma',
'Syscia',
'Tapinomini',
'Tetraponera',
'Zasphinctus']
# Maps class index -> genus name, used to label the model's output logits
idx_to_class = {idx: cls for idx, cls in enumerate(classes)}
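# Preprocessing applied at inference time (presumably mirroring the training
# pipeline, given the name): convert to a [0, 1] tensor, resize to 64x64,
# and convert any CMYK input to RGB.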
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((64, 64), antialias=True),
        CMYKToRGB(),
    ]
)
# Model hyperparameters are positional and must match those used to fine-tune
# the checkpoint (each argument's meaning is defined in IJEPA_finetune.ViTIJEPA)
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
model.eval()  # disable dropout/batch-norm updates for deterministic inference
def ant_genus_classification(image):
    # Preprocess and add a batch dimension: (C, H, W) -> (1, C, H, W)
    image = train_transforms(image).unsqueeze(0)
    with torch.no_grad():
        y_hat = model(image)
        preds = torch.nn.functional.softmax(y_hat, dim=1)
    # Map every class index to its softmax confidence for the Label component
    confidences = {idx_to_class[i]: float(preds[0][i]) for i in range(len(classes))}
    return confidences
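# Gradio UI: takes an uploaded image and shows the ten highest-confidence genera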
demo = gr.Interface(fn=ant_genus_classification, inputs="image", outputs=gr.Label(num_top_classes=10))
if __name__ == "__main__":
    demo.launch(debug=True)