# Spaces: Runtime error
import os | |
import sys | |
import json | |
import torch | |
import kelip | |
import gradio as gr | |
def load_model():
    """Build the KELIP ``ViT-B/32`` model and bundle it with its helpers.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU (CPU otherwise) and is put into eval mode before being returned.

    Returns:
        A dict with the keys ``'model'``, ``'preprocess_img'`` and
        ``'tokenizer'``, holding the three objects returned by
        ``kelip.build_model``.
    """
    net, image_transform, text_tokenizer = kelip.build_model('ViT-B/32')

    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    net = net.to(target_device)
    net.eval()

    return {
        'model': net,
        'preprocess_img': image_transform,
        'tokenizer': text_tokenizer,
    }
def classify(img, user_text):
    """Score an image against a comma-separated list of candidate captions.

    Args:
        img: Image passed straight to ``model_dict['preprocess_img']``
            (the UI in this file supplies it with ``type="pil"``).
        user_text: Candidate labels separated by commas. Whitespace around
            each label is stripped and empty labels are ignored.

    Returns:
        A dict mapping each candidate label to its softmax probability,
        inserted in descending order of probability. An empty dict is
        returned when ``user_text`` contains no usable label.
    """
    # Parse the labels first so that empty input returns before any model
    # work is attempted (tokenizing / encoding an empty list is pointless
    # and an empty label would otherwise take part in the softmax).
    user_texts = [label.strip() for label in (user_text or '').split(',')]
    user_texts = [label for label in user_texts if label]
    if not user_texts:
        return {}

    model = model_dict['model']
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Add a batch dimension to the preprocessed image; both inputs go to
    # the same device the model was loaded on.
    input_img = model_dict['preprocess_img'](img).unsqueeze(0).to(device)
    input_texts = model_dict['tokenizer'].encode(user_texts).to(device)

    # Pure inference: keep every forward pass and the follow-up math out of
    # autograd, not just the image encoder.
    with torch.no_grad():
        image_features = model.encode_image(input_img)
        text_features = model.encode_text(input_texts)

        # l2 normalize
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # topk over all labels yields them sorted by descending probability.
    values, indices = similarity[0].topk(len(user_texts))
    return {
        user_texts[index]: value
        for value, index in zip(values.tolist(), indices.tolist())
    }
if __name__ == '__main__':
    # classify() looks this name up at module level; assigning it here,
    # at module scope, is what makes it visible there.
    model_dict = load_model()

    demo_status = (
        "Demo is running on GPU"
        if torch.cuda.is_available()
        else "Demo is running on CPU"
    )
    description = f"Details: paper_url. {demo_status}"

    # Each example row is [image file, comma-separated candidate captions].
    # NOTE(review): the caption strings look mis-encoded in this copy of the
    # file - confirm them against the original source before relying on them.
    examples = [
        ["squid_sundae.jpg", "์ค์ง์ด ์๋,๊น๋ฐฅ,์๋,๋ก๋ณถ์ด"],
        ["seokchon_lake.jpg", "ํํ์๋ฌธ,์ฌ๋ฆผํฝ๊ณต์,๋กฏ๋ฐ์๋,์์ดํธ์"],
        ["seokchon_lake.jpg", "๋ด,์ฌ๋ฆ,๊ฐ์,๊ฒจ์ธ"],
        ["hwangchil_tree.jpg", "ํฉ์น ๋๋ฌด ๋ฌ๋ชฉ,ํฉ์น ๋๋ฌด,๋,์๋๋ฌด ๋ฌ๋ชฉ,์ผ์์"],
    ]

    iface = gr.Interface(
        fn=classify,
        inputs=[
            gr.inputs.Image(type="pil", label="Image"),
            gr.inputs.Textbox(lines=5, label="Caption"),
        ],
        outputs=['label'],
        examples=examples,
        title="KELIP",
        description=description,
        article="",
    )
    iface.launch()