Spaces:
Runtime error
Runtime error
File size: 2,609 Bytes
f9165b2 8b10a39 f9165b2 fce23d9 f9165b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import os
import sys
import json
import torch
import kelip
import gradio as gr
def load_model():
    """Build the KELIP ViT-B/32 model and bundle it with its helpers.

    The model is moved to CUDA when ``torch.cuda.is_available()`` is true
    (otherwise it stays on CPU) and is switched to eval mode.

    Returns:
        dict with keys ``'model'``, ``'preprocess_img'`` and ``'tokenizer'``,
        holding the three objects returned by ``kelip.build_model``.
    """
    net, img_transform, text_tokenizer = kelip.build_model('ViT-B/32')
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    net = net.to(target_device)
    net.eval()
    return dict(
        model=net,
        preprocess_img=img_transform,
        tokenizer=text_tokenizer,
    )
def classify(img, user_text):
    """Rank comma-separated candidate captions against an image with KELIP.

    Uses the module-level ``model_dict`` (the dict built by ``load_model()``)
    for the model, the image preprocessor and the tokenizer.

    Args:
        img: the input image. The Gradio Image input in this file uses
            ``type="pil"``, so presumably a PIL image (confirm for other
            callers).
        user_text: comma-separated candidate captions, e.g. ``"a,b,c"``.
            Whitespace around each caption is ignored. Empty entries and
            repeated captions are dropped. ``None`` is treated as ``""``.

    Returns:
        dict mapping each caption to its softmax probability, inserted in
        descending order of probability. Returns an empty dict when no
        caption remains after parsing.
    """
    # Parse the caption list: split on commas, trim whitespace, and drop
    # empties and duplicates (first occurrence wins), so "cat, dog,,cat"
    # becomes ["cat", "dog"]. Duplicates would otherwise collide as result
    # keys and split the softmax probability mass between them.
    candidates = (text.strip() for text in (user_text or '').split(','))
    user_texts = list(dict.fromkeys(text for text in candidates if text))
    if not user_texts:
        # Nothing to rank; skip the model entirely.
        return {}

    model = model_dict['model']
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Batch of one image; the captions are tokenized as a single batch.
    input_img = model_dict['preprocess_img'](img).unsqueeze(0).to(device)
    input_texts = model_dict['tokenizer'].encode(user_texts).to(device)

    # Inference only: keep both encoders out of autograd.
    with torch.no_grad():
        image_features = model.encode_image(input_img)
        text_features = model.encode_text(input_texts)

    # L2-normalise so the dot product below is a cosine similarity, then
    # scale by 100 before the softmax over captions.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # topk over all captions returns them sorted by descending probability.
    values, indices = similarity[0].topk(len(user_texts))
    return {user_texts[int(idx)]: val.item() for val, idx in zip(values, indices)}
if __name__ == '__main__':
    # classify() reads `model_dict` as a module global, so it is bound here at
    # module scope. A `global` statement is unnecessary at module level.
    model_dict = load_model()

    # The gr.inputs.* namespace was deprecated in Gradio 3 and removed in
    # Gradio 4. The top-level components take the same arguments.
    inputs = [gr.Image(type="pil", label="Image"),
              gr.Textbox(lines=5, label="Caption"),
              ]
    # 'label' is Gradio's string shortcut for a Label output component.
    outputs = ['label']
    title = "KELIP"
    if torch.cuda.is_available():
        demo_status = "Demo is running on GPU"
    else:
        demo_status = "Demo is running on CPU"
    # TODO(review): "paper_url" looks like an unfilled placeholder for a link
    # to the KELIP paper.
    description = f"Details: paper_url. {demo_status}"
    # Each example row supplies one value per input: [image file,
    # comma-separated candidate captions].
    # NOTE(review): the Korean caption strings appear mis-encoded (mojibake)
    # in this copy of the file. Restore them from the original source.
    examples = [
        ["squid_sundae.jpg", "์ค์ง์ด ์๋,๊น๋ฐฅ,์๋,๋ก๋ณถ์ด"],
        ["seokchon_lake.jpg", "ํํ์๋ฌธ,์ฌ๋ฆผํฝ๊ณต์,๋กฏ๋ฐ์๋,์์ดํธ์"],
        ["seokchon_lake.jpg", "๋ด,์ฌ๋ฆ,๊ฐ์,๊ฒจ์ธ"],
        ["hwangchil_tree.jpg", "ํฉ์น ๋๋ฌด ๋ฌ๋ชฉ,ํฉ์น ๋๋ฌด,๋,์๋๋ฌด ๋ฌ๋ชฉ,์ผ์์"],
    ]
    article = ""
    iface = gr.Interface(
        fn=classify,
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        title=title,
        description=description,
        article=article,
    )
    iface.launch()