# Spaces: Sleeping (status banner captured with the file; not part of the program)
import gradio as gr
import torch
import torchvision
from einops import rearrange
from torchvision.transforms import Compose

from IJEPA_finetune import ViTIJEPA
# Class labels the classifier can predict.  Order matters: position i in this
# list is the label reported for model output i.
classes = [
    'Acanthostichus',
    'Aenictus',
    'Amblyopone',
    'Attini',
    'Bothriomyrmecini',
    'Camponotini',
    'Cerapachys',
    'Cheliomyrmex',
    'Crematogastrini',
    'Cylindromyrmex',
    'Dolichoderini',
    'Dorylus',
    'Eciton',
    'Ectatommini',
    'Formicini',
    'Fulakora',
    'Gesomyrmecini',
    'Gigantiopini',
    'Heteroponerini',
    'Labidus',
    'Lasiini',
    'Leptomyrmecini',
    'Lioponera',
    'Melophorini',
    'Myopopone',
    'Myrmecia',
    'Myrmelachistini',
    'Myrmicini',
    'Myrmoteratini',
    'Mystrium',
    'Neivamyrmex',
    'Nomamyrmex',
    'Oecophyllini',
    'Ooceraea',
    'Paraponera',
    'Parasyscia',
    'Plagiolepidini',
    'Platythyreini',
    'Pogonomyrmecini',
    'Ponerini',
    'Prionopelta',
    'Probolomyrmecini',
    'Proceratiini',
    'Pseudomyrmex',
    'Solenopsidini',
    'Stenammini',
    'Stigmatomma',
    'Syscia',
    'Tapinomini',
    'Tetraponera',
    'Zasphinctus',
]
# Lookup from model output index to class name.
# NOTE: despite its name this maps index -> class; the name is kept because
# the prediction function in this file looks labels up through it.
class_to_idx = dict(enumerate(classes))

# Preprocessing applied to every input: resize to 64x64 (antialiased).
tf = Compose([torchvision.transforms.Resize((64, 64), antialias=True)])

# NOTE(review): the positional arguments follow ViTIJEPA's constructor in
# IJEPA_finetune; presumably the leading 64 is the image size matching the
# resize above -- confirm against that module.  The last argument sizes the
# output to one score per entry in `classes`.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))

# Load the fine-tuned weights onto the CPU regardless of where they were saved.
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))

# Inference only: switch dropout / normalisation layers to evaluation
# behaviour so predictions are deterministic.
model.eval()
def ant_genus_classification(image):
    """Classify one ant image and return a confidence for every known class.

    Args:
        image: A single image as an array-like in height x width x channel
            layout (the ``'b h w c'`` pattern below requires exactly that
            layout once the batch axis is added).  Presumably the RGB array
            produced by the Gradio image input -- confirm if called from
            elsewhere.

    Returns:
        dict[str, float]: every name in ``classes`` mapped to its softmax
        probability, in the form a ``gr.Label`` output consumes.
    """
    # NOTE(review): pixel values are passed through unscaled; confirm this
    # matches the preprocessing used when the model was trained.
    pixels = torch.as_tensor(image, dtype=torch.float32).unsqueeze(0)

    # The resize transform and the model expect channels-first input.
    batch = tf(rearrange(pixels, 'b h w c -> b c h w'))

    # No gradients are needed for inference.
    with torch.no_grad():
        logits = model(batch)[0]  # scores for the single image in the batch
        probabilities = torch.nn.functional.softmax(logits, dim=0)

    # Output position i corresponds to classes[i].
    return dict(zip(classes, probabilities.tolist()))
# Web UI: one image in, the three most likely classes out.
demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=3),
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(debug=True)