|
import gradio as gr |
|
from IJEPA_finetune import ViTIJEPA |
|
import torch |
|
from einops import rearrange |
|
from torchvision.transforms import Compose |
|
import torchvision |
|
|
|
# Label names the classifier can report. Order is significant: the entry at
# position i is paired with element i of the model's output when the
# confidences are assembled.
# NOTE(review): presumably this must match the label order used at training
# time -- confirm against the training pipeline.
classes = [
    'Acanthostichus',
    'Aenictus',
    'Amblyopone',
    'Attini',
    'Bothriomyrmecini',
    'Camponotini',
    'Cerapachys',
    'Cheliomyrmex',
    'Crematogastrini',
    'Cylindromyrmex',
    'Dolichoderini',
    'Dorylus',
    'Eciton',
    'Ectatommini',
    'Formicini',
    'Fulakora',
    'Gesomyrmecini',
    'Gigantiopini',
    'Heteroponerini',
    'Labidus',
    'Lasiini',
    'Leptomyrmecini',
    'Lioponera',
    'Melophorini',
    'Myopopone',
    'Myrmecia',
    'Myrmelachistini',
    'Myrmicini',
    'Myrmoteratini',
    'Mystrium',
    'Neivamyrmex',
    'Nomamyrmex',
    'Oecophyllini',
    'Ooceraea',
    'Paraponera',
    'Parasyscia',
    'Plagiolepidini',
    'Platythyreini',
    'Pogonomyrmecini',
    'Ponerini',
    'Prionopelta',
    'Probolomyrmecini',
    'Proceratiini',
    'Pseudomyrmex',
    'Solenopsidini',
    'Stenammini',
    'Stigmatomma',
    'Syscia',
    'Tapinomini',
    'Tetraponera',
    'Zasphinctus',
]

# NOTE: despite its name, this maps integer index -> class name
# (0 -> 'Acanthostichus', ...). The name is kept so existing references
# keep working.
class_to_idx = dict(enumerate(classes))
|
|
|
# Preprocessing applied to every incoming image batch: resize to 64x64 with
# antialiasing. No other transform (e.g. value rescaling) is applied here.
tf = Compose([torchvision.transforms.Resize((64, 64), antialias=True)])

# Build the network and restore the fine-tuned weights on CPU.
# NOTE(review): the meaning of the positional hyperparameters is defined by
# IJEPA_finetune.ViTIJEPA; the last one is presumably the number of output
# classes -- confirm against that constructor.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))

# torch.load unpickles the checkpoint, so only load files from a trusted
# source.
model.load_state_dict(
    torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu'))
)

# Switch to inference mode so layers such as dropout / batch norm (if the
# model has any) behave deterministically while serving predictions.
model.eval()
|
|
|
|
|
def ant_genus_classification(image):
    """Classify an ant image and return a confidence score for every class.

    Args:
        image: Image data laid out as (height, width, channels); a batch
            axis is added and the channels are moved first before the
            resize transform and the model are applied. Presumably the
            NumPy array Gradio supplies for an ``"image"`` input -- confirm
            against the Gradio version in use. ``None`` (no image
            submitted) is tolerated.

    Returns:
        A dict mapping each class name to its softmax probability as a
        Python float, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # Nothing was submitted, so there is nothing to classify.
        return None

    # Convert to a float32 tensor and add a leading batch axis.
    # NOTE(review): pixel values are passed through without rescaling or
    # normalisation -- confirm this matches the training preprocessing.
    batch = torch.as_tensor(image, dtype=torch.float32).unsqueeze(0)
    batch = rearrange(batch, 'b h w c -> b c h w')
    batch = tf(batch)

    with torch.no_grad():
        # First element of the model output; presumably the logits of the
        # single image in the batch.
        logits = model(batch)[0]
        prediction = torch.nn.functional.softmax(logits, dim=0)

    # One bulk conversion to Python floats instead of one per class.
    probabilities = prediction.tolist()
    return {name: probabilities[idx] for idx, name in class_to_idx.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: a single image input wired to the classifier, with the three
# highest-scoring classes shown in the label output.
demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=3),
)

# Start the app only when the file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(debug=True)
|
|