File size: 2,582 Bytes
88ae77c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import gradio as gr
from IJEPA_finetune import ViTIJEPA
import torch
from einops import rearrange
from torchvision.transforms import Compose
import torchvision
# Output labels of the classifier, in the order of the model's output logits
# (position i of the logit vector corresponds to classes[i]).
classes = [
    'Acanthostichus',
    'Aenictus',
    'Amblyopone',
    'Attini',
    'Bothriomyrmecini',
    'Camponotini',
    'Cerapachys',
    'Cheliomyrmex',
    'Crematogastrini',
    'Cylindromyrmex',
    'Dolichoderini',
    'Dorylus',
    'Eciton',
    'Ectatommini',
    'Formicini',
    'Fulakora',
    'Gesomyrmecini',
    'Gigantiopini',
    'Heteroponerini',
    'Labidus',
    'Lasiini',
    'Leptomyrmecini',
    'Lioponera',
    'Melophorini',
    'Myopopone',
    'Myrmecia',
    'Myrmelachistini',
    'Myrmicini',
    'Myrmoteratini',
    'Mystrium',
    'Neivamyrmex',
    'Nomamyrmex',
    'Oecophyllini',
    'Ooceraea',
    'Paraponera',
    'Parasyscia',
    'Plagiolepidini',
    'Platythyreini',
    'Pogonomyrmecini',
    'Ponerini',
    'Prionopelta',
    'Probolomyrmecini',
    'Proceratiini',
    'Pseudomyrmex',
    'Solenopsidini',
    'Stenammini',
    'Stigmatomma',
    'Syscia',
    'Tapinomini',
    'Tetraponera',
    'Zasphinctus',
]
# NOTE: despite its name, this maps index -> class label (not label -> index).
class_to_idx = dict(enumerate(classes))
# Preprocessing: resize every incoming image to 64x64 before it reaches the model.
tf = Compose([torchvision.transforms.Resize((64, 64), antialias=True)])

# NOTE(review): these positional hyper-parameters must match the configuration
# used to train vit_ijepa_ant_1.pt; their meaning is defined in IJEPA_finetune.
# The last argument is the number of output classes.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
# torch.load unpickles the file: acceptable for this trusted local checkpoint,
# but never point it at user-supplied files. map_location keeps it CPU-only.
model.load_state_dict(
    torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu'))
)
# Inference only: switch dropout/normalization layers (if the model has any)
# to evaluation behaviour so predictions are deterministic.
model.eval()
def ant_genus_classification(image):
    """Classify an ant image into one of the labels in ``classes``.

    Args:
        image: Array supplied by the Gradio ``"image"`` input, laid out as
            (height, width, channels); the ``rearrange`` below relies on that
            order. Gradio passes ``None`` when the form is submitted without
            an image.

    Returns:
        dict mapping every label in ``classes`` to its softmax probability,
        the format ``gr.Label`` consumes. Returns an empty dict when no image
        was provided.
    """
    if image is None:
        # Without this guard torch.Tensor(None) raises. An empty dict gives the
        # Label output a clean "no prediction" instead.
        return {}
    # Float tensor with a leading batch axis: (1, H, W, C).
    batch = torch.Tensor(image).unsqueeze(0)
    # Move channels first, as torchvision transforms and the model expect.
    batch = rearrange(batch, 'b h w c -> b c h w')
    batch = tf(batch)
    with torch.no_grad():
        # [0] drops the batch axis; softmax turns logits into probabilities.
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Logit i corresponds to classes[i].
    return dict(zip(classes, probabilities.tolist()))
# Web UI: upload an image, show the three most probable labels.
demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=3),
)

if __name__ == "__main__":
    # Launch only when run as a script, not when imported.
    demo.launch(debug=True)
|