File size: 2,582 Bytes
88ae77c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
from IJEPA_finetune import ViTIJEPA
import torch
from einops import rearrange
from torchvision.transforms import Compose
import torchvision

# Label names the classifier can report. The position of a name in this list
# is the index of the matching model output.
classes = [
    'Acanthostichus', 'Aenictus', 'Amblyopone', 'Attini',
    'Bothriomyrmecini', 'Camponotini', 'Cerapachys', 'Cheliomyrmex',
    'Crematogastrini', 'Cylindromyrmex', 'Dolichoderini', 'Dorylus',
    'Eciton', 'Ectatommini', 'Formicini', 'Fulakora',
    'Gesomyrmecini', 'Gigantiopini', 'Heteroponerini', 'Labidus',
    'Lasiini', 'Leptomyrmecini', 'Lioponera', 'Melophorini',
    'Myopopone', 'Myrmecia', 'Myrmelachistini', 'Myrmicini',
    'Myrmoteratini', 'Mystrium', 'Neivamyrmex', 'Nomamyrmex',
    'Oecophyllini', 'Ooceraea', 'Paraponera', 'Parasyscia',
    'Plagiolepidini', 'Platythyreini', 'Pogonomyrmecini', 'Ponerini',
    'Prionopelta', 'Probolomyrmecini', 'Proceratiini', 'Pseudomyrmex',
    'Solenopsidini', 'Stenammini', 'Stigmatomma', 'Syscia',
    'Tapinomini', 'Tetraponera', 'Zasphinctus',
]

# NOTE: despite the name, this maps output index -> class name.
class_to_idx = dict(enumerate(classes))

# Preprocessing pipeline: resize the input to 64x64 (with antialiasing).
tf = Compose([torchvision.transforms.Resize((64, 64), antialias=True)])

# Build the network and restore its weights onto the CPU.
# NOTE(review): the positional arguments are defined by IJEPA_finetune.ViTIJEPA,
# which is not visible here; the last one (len(classes)) is presumably the
# number of output classes — confirm against that constructor before changing.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the checkpoint supports it).
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
# The model is only used for inference: switch layers such as dropout and
# batch norm out of their training behaviour so predictions are deterministic.
model.eval()


def ant_genus_classification(image):
    """Classify an image and return a confidence for every known class.

    Args:
        image: Array-like image in height x width x channel layout (the
            layout the ``'b h w c -> b c h w'`` rearrange expects once a
            batch axis has been added), or ``None`` when no image was given.

    Returns:
        A dict mapping each name in ``classes`` to its softmax probability
        as a Python float. An empty dict is returned when ``image`` is
        ``None`` instead of failing inside the tensor conversion.
    """
    # The UI can invoke the callback without an uploaded image.
    if image is None:
        return {}

    # Explicit float32 conversion, then add the batch axis the model expects.
    batch = torch.as_tensor(image, dtype=torch.float32).unsqueeze(0)
    # Channels-last -> channels-first, then resize via the shared transform.
    batch = tf(rearrange(batch, 'b h w c -> b c h w'))

    # Inference only: no gradients needed.
    with torch.no_grad():
        # model(batch)[0] is the score vector for the single image in the batch.
        probabilities = torch.nn.functional.softmax(model(batch)[0], dim=0)

    # Pair each class name with the probability at the same index.
    return dict(zip(classes, probabilities.tolist()))


# Gradio UI: a single image input wired to the classifier; the output label
# widget is configured with num_top_classes=3.
demo = gr.Interface(
    fn=ant_genus_classification,
    inputs="image",
    outputs=gr.Label(num_top_classes=3),
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)