File size: 2,609 Bytes
f9165b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b10a39
 
 
 
 
 
 
f9165b2
 
 
 
 
 
fce23d9
f9165b2
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import sys
import json
import torch
import kelip
import gradio as gr

def load_model():
    """Build the KELIP 'ViT-B/32' model and bundle it with its helpers.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise to the CPU, and is switched to evaluation mode.

    Returns:
        dict with three entries:
            'model': the network, on the selected device, in eval mode.
            'preprocess_img': the image preprocessing callable returned by
                ``kelip.build_model``.
            'tokenizer': the tokenizer returned by ``kelip.build_model``.
    """
    net, image_transform, text_tokenizer = kelip.build_model('ViT-B/32')

    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    net = net.to(target)
    net.eval()

    return {
        'model': net,
        'preprocess_img': image_transform,
        'tokenizer': text_tokenizer,
    }

def classify(img, user_text):
    """Score an image against comma-separated candidate labels.

    The image and every label are encoded with the model stored in the
    module-level ``model_dict``; the L2-normalised features are compared and
    the scaled similarities are turned into probabilities with a softmax over
    the labels.

    Args:
        img: Image in whatever form ``model_dict['preprocess_img']`` accepts
            (the Gradio input in this file is configured with ``type="pil"``).
        user_text: Candidate labels separated by commas. Surrounding
            whitespace is stripped from each label, blank entries are
            dropped, and repeated labels are kept only once.

    Returns:
        dict mapping each label to its probability (float), inserted in
        descending order of probability. An empty dict is returned when
        ``user_text`` contains no usable label.

    Raises:
        ValueError: If labels were supplied but ``img`` is None.
    """
    # Parse the labels before touching the model: with nothing to compare
    # against there is no reason to run the encoders on an empty batch.
    stripped = (part.strip() for part in (user_text or '').split(','))
    # dict.fromkeys de-duplicates while preserving order; duplicate labels
    # would otherwise collide as result keys and split the softmax mass.
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        return {}

    if img is None:
        raise ValueError("An image is required to classify the given labels.")

    model = model_dict['model']
    # Same device rule as load_model(), so inputs land where the model lives.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with torch.no_grad():
        # extract image features (unsqueeze adds the batch dimension)
        input_img = model_dict['preprocess_img'](img).unsqueeze(0).to(device)
        image_features = model.encode_image(input_img)

        # extract text features
        input_texts = model_dict['tokenizer'].encode(labels).to(device)
        text_features = model.encode_text(input_texts)

        # l2 normalize
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Similarities are scaled by 100.0 before the softmax over labels.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # topk over all labels yields every label sorted by descending score.
    values, indices = similarity[0].topk(len(labels))
    return {
        labels[index]: value
        for value, index in zip(values.tolist(), indices.tolist())
    }

if __name__ == '__main__':
    # Load the model once at start-up. `classify` looks it up through the
    # module-level name `model_dict`; a plain assignment here already creates
    # that module global, so no `global` statement is needed.
    model_dict = load_model()

    # NOTE(review): `gr.inputs.*` is the legacy Gradio component namespace;
    # confirm the pinned Gradio version still provides it before upgrading.
    inputs = [gr.inputs.Image(type="pil", label="Image"),
              gr.inputs.Textbox(lines=5, label="Caption"),
              ]

    outputs = ['label']

    title = "KELIP"

    if torch.cuda.is_available():
        demo_status = "Demo is running on GPU"
    else:
        demo_status = "Demo is running on CPU"
    description = f"Details: paper_url. {demo_status}"

    # Each example is [image path, comma-separated candidate labels].
    # The label strings are Korean text stored as UTF-8.
    examples = [
        # squid sundae, gimbap, sundae, tteokbokki
        ["squid_sundae.jpg", "오징어 순대,김밥,순대,떡볶이"],
        # Peace Gate, Olympic Park, Lotte World, Seokchon Lake
        ["seokchon_lake.jpg", "평화의문,올림픽공원,롯데월드,석촌호수"],
        # spring, summer, autumn, winter
        ["seokchon_lake.jpg", "봄,여름,가을,겨울"],
        # hwangchil sapling, hwangchil tree, orchid, pine sapling, palm tree
        ["hwangchil_tree.jpg", "황칠 나무 묘목,황칠 나무,난,소나무 묘목,야자수"],
    ]

    article = ""

    iface = gr.Interface(
        fn=classify,
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        title=title,
        description=description,
        article=article
    )
    iface.launch()