# Spaces: Runtime error
# File size: 1,483 Bytes
# 1bb90fa 2c140eb
import data
import torch
import gradio as gr
from models import imagebind_model
from models.imagebind_model import ModalityType
# Pick the compute device: CUDA device 0 when torch reports CUDA as
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the ImageBind "huge" model (requested with pretrained=True), switch
# it to evaluation mode, and move it onto the selected device.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)
def image_text_zeroshot(image, text_list):
    """Score one image against a set of candidate text labels.

    The image and every candidate label are embedded by the module-level
    ``model``; the image embedding is compared with each text embedding by
    dot product and the similarities are normalised with a softmax.

    Args:
        image: The image to classify, as delivered by the Gradio image
            input. Either a filesystem path or a file-like object whose
            ``name`` attribute holds the path.
        text_list: All candidate labels in a single string, separated by
            ``|`` (for example ``"A dog|A car|A bird"``). Spaces around
            each label are stripped.

    Returns:
        A dict mapping each candidate label to its softmax score, in the
        shape expected by a Gradio "label" output.
    """
    # A file-like upload exposes its on-disk path as `.name`; a plain path
    # string has no such attribute and is passed through unchanged.
    image_paths = [getattr(image, "name", image)]

    # Split on "|" so that a label may itself contain commas.
    labels = [label.strip(" ") for label in text_list.strip(" ").split("|")]

    inputs = {
        # Tokenise the parsed labels (not the raw string) so that there is
        # exactly one text embedding per label.
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    # Compare the image against the TEXT embeddings: those are the only
    # other modality that was fed to the model above.
    # NOTE(review): squeeze(0) presumably drops the single-image batch
    # axis, leaving one score per label - confirm against the model output.
    scores = torch.softmax(
        embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T,
        dim=-1,
    ).squeeze(0).tolist()

    return dict(zip(labels, scores))
# Gradio input widgets, listed in the same order as the parameters of
# image_text_zeroshot(image, text_list).
inputs = [
    gr.inputs.Image(type='file', label="Input image"),
    gr.inputs.Textbox(lines=1, label="Candidate texts"),
]

# Wire the classifier into a Gradio interface with a "label" output.
# Each example row pairs an image path with a candidate-text string.
iface = gr.Interface(
    fn=image_text_zeroshot,
    inputs=inputs,
    outputs="label",
    examples=[
        [".assets/dog_image.jpg", "A dog|A car|A bird"],
        [".assets/car_image.jpg", "A dog|A car|A bird"],
        [".assets/bird_image.jpg", "A dog|A car|A bird"],
    ],
    description="""Zeroshot test""",
    title="Zero-shot Classification",
)

# Start the web app.
iface.launch()