File size: 3,318 Bytes
88ae77c
 
 
 
 
 
 
5f73c99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88ae77c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b44e693
 
 
393b535
b44e693
 
 
 
88ae77c
 
 
 
 
 
b44e693
88ae77c
 
 
3798706
b687777
 
9a82fe2
b44e693
be1fa74
88ae77c
 
 
 
 
 
 
 
 
adbabdb
88ae77c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
from IJEPA_finetune import ViTIJEPA
import torch
from einops import rearrange
from torchvision.transforms import Compose
import torchvision

class CMYKToRGB(object):
    """Transform converting a 4-channel CMYK tensor to a 3-channel RGB tensor.

    Tensors whose channel dimension is not 4 (e.g. already-RGB images) are
    returned unchanged. Expects values in [0, 1], as produced by
    ``torchvision.transforms.ToTensor``.
    """

    def __call__(self, img):
        # Ensure the input is a tensor (the original comment incorrectly
        # claimed a PIL Image was expected).
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input image should be a torch.Tensor")

        # Pass non-CMYK images through untouched.
        if img.shape[-3] != 4:
            return img

        # Extract CMYK channels along the channel dimension.
        c, m, y, k = img.unbind(-3)

        # Standard CMYK -> RGB conversion, kept in [0, 1].
        # Bug fix: the original multiplied by 255, which put converted CMYK
        # images on a [0, 255] scale while RGB inputs stayed in [0, 1]
        # (ToTensor yields [0, 1]) — inconsistent ranges fed to the model.
        r = (1 - c) * (1 - k)
        g = (1 - m) * (1 - k)
        b = (1 - y) * (1 - k)

        # Stack RGB channels back along the channel dimension.
        return torch.stack([r, g, b], dim=-3)

# Ant class labels, in the index order the trained model's outputs follow.
classes = [
    'Acanthostichus',
    'Aenictus',
    'Amblyopone',
    'Attini',
    'Bothriomyrmecini',
    'Camponotini',
    'Cerapachys',
    'Cheliomyrmex',
    'Crematogastrini',
    'Cylindromyrmex',
    'Dolichoderini',
    'Dorylus',
    'Eciton',
    'Ectatommini',
    'Formicini',
    'Fulakora',
    'Gesomyrmecini',
    'Gigantiopini',
    'Heteroponerini',
    'Labidus',
    'Lasiini',
    'Leptomyrmecini',
    'Lioponera',
    'Melophorini',
    'Myopopone',
    'Myrmecia',
    'Myrmelachistini',
    'Myrmicini',
    'Myrmoteratini',
    'Mystrium',
    'Neivamyrmex',
    'Nomamyrmex',
    'Oecophyllini',
    'Ooceraea',
    'Paraponera',
    'Parasyscia',
    'Plagiolepidini',
    'Platythyreini',
    'Pogonomyrmecini',
    'Ponerini',
    'Prionopelta',
    'Probolomyrmecini',
    'Proceratiini',
    'Pseudomyrmex',
    'Solenopsidini',
    'Stenammini',
    'Stigmatomma',
    'Syscia',
    'Tapinomini',
    'Tetraponera',
    'Zasphinctus',
]

# Maps model output index -> class label. NOTE(review): despite the name,
# this maps index to class, not class to index; the name is kept because
# callers reference it.
class_to_idx = dict(enumerate(classes))

# Preprocessing pipeline applied to every incoming image:
# array -> float tensor in [0, 1], resized to 64x64, CMYK converted to RGB.
# (Uses the `Compose` imported at the top of the file.)
train_transforms = Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((64, 64), antialias=True),
    CMYKToRGB(),
])


# Build the ViT-IJEPA classifier and load the fine-tuned weights on CPU.
# NOTE(review): the positional constructor args (64, 4, 3, 64, 8, 8, ...) are
# presumably image size / patch size / channels / embed dim / depth / heads —
# confirm against ViTIJEPA's signature in IJEPA_finetune.
model = ViTIJEPA(64, 4, 3, 64, 8, 8, len(classes))
model.load_state_dict(torch.load("vit_ijepa_ant_1.pt", map_location=torch.device('cpu')))
# Bug fix: switch to inference mode so dropout/batch-norm layers behave
# deterministically; the original left the model in training mode.
model.eval()


def ant_genus_classification(image):
    """Classify an ant image and return {class label: softmax confidence}.

    `image` is whatever Gradio's "image" input supplies (a numpy array);
    `train_transforms` converts it to a normalized (3, 64, 64) tensor, and
    `unsqueeze(0)` adds the batch dimension the model expects.
    """
    batch = train_transforms(image).unsqueeze(0)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]

    # One confidence per class; gr.Label picks and displays the top entries.
    # (Removed debug prints and dead commented-out code from the original.)
    return {class_to_idx[i]: float(probs[i]) for i in range(len(classes))}


# Gradio UI: a single image upload input, with the classifier's top 10
# predicted classes (and confidences) shown as the output label.
demo = gr.Interface(fn=ant_genus_classification, inputs="image", outputs=gr.Label(num_top_classes=10))

if __name__ == "__main__":
    # debug=True surfaces tracebacks in the browser while developing.
    demo.launch(debug=True)