# Spaces status (scraped page header, not part of the program): Runtime error
import gradio
import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
# Checkpoint shared by the processor and the model so the two cannot drift apart.
# NOTE(review): the original passed the bare name "vilt-b32-finetuned-vqa", which
# only resolves if a local directory of that name exists next to the app. The
# published checkpoint lives under the "dandelin" namespace; confirm that no
# local copy was intended.
MODEL_ID = "dandelin/vilt-b32-finetuned-vqa"

# Loaded once at import time; predict_answer reuses both objects on every request.
processor = ViltProcessor.from_pretrained(MODEL_ID)
model = ViltForQuestionAnswering.from_pretrained(MODEL_ID)
def predict_answer(image, question):
    """Return the five most probable answers to *question* about *image*.

    Args:
        image: Array-like pixel data exposing ``.astype`` (presumably the
            NumPy array handed over by the Gradio image input; confirm
            against the interface definition). It is converted to an RGB
            PIL image before being passed to the processor.
        question: The question text forwarded to the processor.

    Returns:
        A list of exactly five strings, ordered from most to least
        probable. Each populated entry is formatted as
        ``"<answer>: <probability as a percentage>"``. Entries are empty
        strings when fewer than five answers have a non-zero probability,
        so the list length always matches the five output fields.
    """
    num_answers = 5

    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    encoding = processor(pil_image, question, return_tensors="pt")

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**encoding).logits

    # Single-item batch, so take row 0 of the softmax over answer classes.
    probs = logits.softmax(dim=-1)[0]

    # topk returns values in descending order; clamp k so a model with fewer
    # than five labels cannot cause an out-of-range request.
    top_probs, top_indices = probs.topk(min(num_answers, probs.numel()))

    answer_list = [
        f"{model.config.id2label[idx]}: {prob:.2%}"
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        if prob > 0.00
    ]

    # The interface declares five outputs; pad so the count always matches
    # even when some probabilities were filtered out above.
    answer_list.extend([""] * (num_answers - len(answer_list)))
    return answer_list
# Input widgets, listed in the order predict_answer takes its arguments
# (image first, then question).
inputs = [
    gradio.components.Image(label="Image"),
    gradio.components.Textbox(label="Question", placeholder="Enter your question here."),
]

# One text field per ranked answer: "Answer 1" through "Answer 5".
outputs = [
    gradio.components.Textbox(label=f"Answer {rank}") for rank in range(1, 6)
]

title = "Visual Question Answering (vilt-b32-finetuned-vqa)"

# Bind the interface to a name before launching so it stays reachable
# after construction. The custom CSS hides the page footer.
demo = gradio.Interface(
    fn=predict_answer,
    inputs=inputs,
    outputs=outputs,
    title=title,
    allow_flagging="never",
    css="footer{display:none !important}",
)
demo.launch()