# Provenance: copied from a Hugging Face Space file view.
# Author: Spidartist · commit adbabdb ("Update app.py") · 3.32 kB
import gradio as gr
from IJEPA_finetune import ViTIJEPA
import torch
from einops import rearrange
from torchvision.transforms import Compose
import torchvision
class CMYKToRGB:
    """Callable transform that maps a 4-channel (CMYK) tensor to 3-channel RGB.

    The channel axis is taken to be the third-from-last dimension, so a
    single ``(4, H, W)`` image and a ``(..., 4, H, W)`` batch are handled the
    same way.  Any tensor whose channel axis is not of size 4 is returned
    unchanged; any tensor whose channel axis *is* of size 4 is treated as
    CMYK.

    NOTE(review): converted output is scaled by 255, whereas tensors that
    are passed through keep whatever range they arrived with -- confirm this
    asymmetry is what the model expects.
    """

    def __call__(self, img):
        """Return an RGB version of ``img`` if it has four channels.

        Args:
            img: A ``torch.Tensor`` with at least three dimensions.

        Returns:
            ``img`` itself when its channel axis is not of size 4; otherwise
            a new tensor with three channels, each computed as
            ``255 * (1 - ink) * (1 - k)`` for ink in (c, m, y).

        Raises:
            TypeError: If ``img`` is not a ``torch.Tensor``.
        """
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input image should be a torch.Tensor")

        channel_axis = -3
        if img.shape[channel_axis] == 4:
            # Split into the three colour inks (C, M, Y) and the key (K).
            *inks, black = img.unbind(channel_axis)
            rgb_channels = [255 * (1 - ink) * (1 - black) for ink in inks]
            return torch.stack(rgb_channels, dim=channel_axis)

        # Not CMYK: nothing to convert.
        return img
# Label names known to the classifier.  The position of each name in this
# list is its class index (see ``class_to_idx`` below), so the order must
# not be changed.
classes = [
    "Acanthostichus",
    "Aenictus",
    "Amblyopone",
    "Attini",
    "Bothriomyrmecini",
    "Camponotini",
    "Cerapachys",
    "Cheliomyrmex",
    "Crematogastrini",
    "Cylindromyrmex",
    "Dolichoderini",
    "Dorylus",
    "Eciton",
    "Ectatommini",
    "Formicini",
    "Fulakora",
    "Gesomyrmecini",
    "Gigantiopini",
    "Heteroponerini",
    "Labidus",
    "Lasiini",
    "Leptomyrmecini",
    "Lioponera",
    "Melophorini",
    "Myopopone",
    "Myrmecia",
    "Myrmelachistini",
    "Myrmicini",
    "Myrmoteratini",
    "Mystrium",
    "Neivamyrmex",
    "Nomamyrmex",
    "Oecophyllini",
    "Ooceraea",
    "Paraponera",
    "Parasyscia",
    "Plagiolepidini",
    "Platythyreini",
    "Pogonomyrmecini",
    "Ponerini",
    "Prionopelta",
    "Probolomyrmecini",
    "Proceratiini",
    "Pseudomyrmex",
    "Solenopsidini",
    "Stenammini",
    "Stigmatomma",
    "Syscia",
    "Tapinomini",
    "Tetraponera",
    "Zasphinctus",
]

# Despite its name, this maps class *index* -> class *name*.
class_to_idx = dict(enumerate(classes))
# Preprocessing pipeline applied to each input image, in this order:
#   1. ToTensor   -- turn the incoming image into a tensor (CMYKToRGB below
#                    only accepts torch.Tensor input).
#   2. Resize     -- rescale to 64x64 with antialiasing.
#   3. CMYKToRGB  -- collapse 4-channel input to 3 channels; other inputs
#                    pass through unchanged.
train_transforms = Compose(
    transforms=[
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(size=(64, 64), antialias=True),
        CMYKToRGB(),
    ],
)
# Build the classifier and restore its trained weights on the CPU.
# The positional arguments are ViTIJEPA hyper-parameters defined in
# IJEPA_finetune; the last one is the number of entries in ``classes``
# (presumably the size of the output layer -- TODO confirm and name the rest).
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
# NOTE(security): torch.load unpickles the checkpoint; only load files from a
# trusted source (consider ``weights_only=True`` where the installed torch
# version supports it).
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
# This module is used for inference only: switch to evaluation mode so that
# layers such as dropout / batch-norm behave deterministically at predict time.
model.eval()
def ant_genus_classification(image):
    """Classify one image and return a confidence for every known class.

    Args:
        image: The value delivered by the Gradio ``"image"`` input; it is
            passed straight to ``train_transforms``.

    Returns:
        dict: class name -> softmax probability (as a Python ``float``),
        with one entry per element of ``classes``.
    """
    # Preprocess and add a leading batch dimension of size 1.
    batch = train_transforms(image).unsqueeze(0)
    print(batch.shape)

    # Gradients are not needed for prediction.
    with torch.no_grad():
        logits = model(batch)
        preds = torch.softmax(logits, dim=1)
        print(preds.shape)

        # Only one image was submitted, so read the first (only) row.
        row = preds[0]
        confidences = {
            class_to_idx[idx]: float(row[idx]) for idx in range(len(classes))
        }

    return confidences
# prediction = model(image)[0]
# prediction = prediction.tolist()
# print(prediction)
# return {
# class_to_idx[i]: prediction[i] for i in range(len(prediction)) if prediction[i] > 0.01
# }
# Gradio UI: an image input wired to the classifier, with a label output
# configured with ``num_top_classes=10``.
demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=10),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)