miculpionier's picture
Update app.py
66e49d4
raw
history blame
No virus
1.48 kB
import gradio
import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
# Checkpoint shared by the processor (image + text preprocessing) and the
# VQA model so the two can never drift apart.
# NOTE(review): this is a bare name, not "org/name". from_pretrained treats it
# as a local directory if one exists with this name; the Hugging Face Hub copy
# is published as "dandelin/vilt-b32-finetuned-vqa". Confirm which is intended.
MODEL_NAME = "vilt-b32-finetuned-vqa"

processor = ViltProcessor.from_pretrained(MODEL_NAME)
model = ViltForQuestionAnswering.from_pretrained(MODEL_NAME)
def predict_answer(image, question):
    """Answer a free-form question about an image with the ViLT VQA model.

    Args:
        image: Image array from the Gradio ``Image`` input (it must support
            ``.astype``), or ``None`` when no image has been uploaded.
        question: The question text.

    Returns:
        A list of exactly five strings, one per output ``Textbox``. Each
        non-empty entry is ``"<answer>: <probability>"`` for one of the
        highest-scoring answers, best first. Unused slots are ``""``.
    """
    # Gradio maps the returned list one-to-one onto the output Textboxes, so
    # the result must always contain exactly this many entries.
    num_answers = 5

    if image is None:
        # Nothing to answer about yet: clear every box instead of crashing.
        return [""] * num_answers

    # convert("RGB") also accepts grayscale or RGBA arrays, which forcing
    # mode="RGB" in fromarray does not handle. RGB input is unchanged.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    encoding = processor(pil_image, question, return_tensors="pt")

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        logits = model(**encoding).logits

    probs = logits.softmax(dim=-1)[0]
    sorted_probs, sorted_indices = probs.sort(descending=True)
    # Slicing (rather than indexing 0..4) tolerates fewer labels than slots.
    top_probs = sorted_probs[:num_answers].tolist()
    top_indices = sorted_indices[:num_answers].tolist()

    answer_list = []
    for prob, idx in zip(top_probs, top_indices):
        if prob > 0.0:
            answer = model.config.id2label[idx]
            answer_list.append(f"{answer}: {prob:.2%}")

    # Pad so the number of returned values always matches the interface.
    answer_list.extend([""] * (num_answers - len(answer_list)))
    return answer_list
# Inputs are passed positionally, so their order must match
# predict_answer(image, question).
inputs = [
    gradio.components.Image(label="Image"),
    gradio.components.Textbox(label="Question", placeholder="Enter your question here."),
]

# Five answer boxes, filled in rank order from predict_answer's returned list.
outputs = [gradio.components.Textbox(label=f"Answer {rank}") for rank in range(1, 6)]

title = "Visual Question Answering (vilt-b32-finetuned-vqa)"

demo = gradio.Interface(
    fn=predict_answer,
    inputs=inputs,
    outputs=outputs,
    title=title,
    allow_flagging="never",
    css="footer{display:none !important}",  # hide the page footer
)
demo.launch()