# Spaces status: Runtime error
import gradio as gr
import torch

import data
from models import imagebind_model
from models.imagebind_model import ModalityType
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Build the ImageBind "huge" model once at start-up (pretrained=True requests
# the pretrained weights) so every request reuses the same instance.
# NOTE(review): assumes imagebind_huge returns a torch nn.Module; eval()
# then switches layers such as dropout to inference behaviour.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)
def image_text_zeroshot(image, text_list):
    """Score how well each candidate text describes an image with ImageBind.

    Args:
        image: Path to the input image file, as produced by a Gradio
            ``Image`` component configured with ``type="filepath"``.
        text_list: Candidate labels separated by ``|``, e.g.
            ``"A dog|A car|A bird"`` (the format used by the demo examples).
            Surrounding whitespace is stripped and empty entries are ignored.

    Returns:
        dict mapping each candidate label to its softmax probability over all
        candidates, suitable for a Gradio ``"label"`` output. Duplicate labels
        collapse to a single key.

    Raises:
        gr.Error: If no image was supplied or ``text_list`` contains no
            non-empty label.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Split on "|" to match the examples; drop blanks left by stray separators.
    labels = [label.strip() for label in text_list.split("|") if label.strip()]
    if not labels:
        raise gr.Error("Please enter at least one candidate text, separated by '|'.")

    inputs = {
        # Tokenize the parsed labels, not the raw string: iterating a str
        # would yield one "label" per character.
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data([image], device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    # Compare the image embedding against the TEXT embeddings. Only TEXT and
    # VISION were fed to the model, so (presumably) no AUDIO entry exists.
    similarity = embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T
    # One image -> shape (1, num_labels); softmax over labels, drop batch dim.
    scores = torch.softmax(similarity, dim=-1).squeeze(0).tolist()
    return dict(zip(labels, scores))
# Gradio input components, in the order image_text_zeroshot takes them.
# type="filepath" hands the function a path string, which is then passed to
# data.load_and_transform_vision_data. The legacy gr.inputs.* namespace is
# deprecated in Gradio 3 and removed in Gradio 4, and gr.Image only accepts
# type "numpy", "pil" or "filepath" (not "file").
inputs = [
    gr.Image(type="filepath", label="Input image"),
    gr.Textbox(
        lines=1,
        label="Candidate texts",
        placeholder="A dog|A car|A bird",
    ),
]
# Wire the classifier into a simple Gradio UI. The "label" output renders the
# dict returned by image_text_zeroshot as ranked confidences. The example
# candidate strings separate labels with "|".
iface = gr.Interface(
    image_text_zeroshot,
    inputs,
    "label",
    examples=[
        [".assets/dog_image.jpg", "A dog|A car|A bird"],
        [".assets/car_image.jpg", "A dog|A car|A bird"],
        [".assets/bird_image.jpg", "A dog|A car|A bird"],
    ],
    description="""Zeroshot test""",
    title="Zero-shot Classification",
)
# Launched at import time, as in the original script.
iface.launch()