"""Gradio demo: visual question answering with a ViLT VQA checkpoint.

Loads the processor and model once at import time, then serves a web UI with
an image input, a question textbox, and one output textbox per top answer.
"""

import gradio
import torch
from PIL import Image
from transformers import ViltForQuestionAnswering, ViltProcessor

# Identifier handed to ``from_pretrained`` for both processor and model.
MODEL_NAME = "vilt-b32-finetuned-vqa"

# Number of answers shown; must equal the number of output components below,
# because Gradio requires one returned value per output component.
NUM_ANSWERS = 5

processor = ViltProcessor.from_pretrained(MODEL_NAME)
model = ViltForQuestionAnswering.from_pretrained(MODEL_NAME)


def _pad_answers(answers):
    """Return *answers* padded with empty strings to exactly NUM_ANSWERS items."""
    return answers + [""] * (NUM_ANSWERS - len(answers))


def predict_answer(image, question):
    """Return the model's top answers for *question* about *image*.

    Args:
        image: Image as a NumPy array as delivered by the Gradio image
            component, or ``None`` when no image was supplied.
        question: The question text.

    Returns:
        A list of exactly ``NUM_ANSWERS`` strings. Each non-empty entry is
        formatted as ``"<answer>: <probability as percent>"``, ordered from
        most to least probable. Entries whose probability is zero are
        omitted and the list is padded with empty strings so that every
        output textbox always receives a value.
    """
    if image is None:
        # Without an image there is nothing to run; report it in the first box
        # instead of failing on ``None.astype``.
        return _pad_answers(["Please provide an image."])

    # Let Pillow infer the array's mode, then normalise to RGB. This also
    # copes with grayscale / RGBA arrays, which a forced 'RGB' mode does not.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    encoding = processor(pil_image, question, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoding).logits

    probs = logits.softmax(dim=-1)[0]
    top_probs, top_indices = probs.topk(min(NUM_ANSWERS, probs.numel()))

    answers = [
        f"{model.config.id2label[idx]}: {prob:.2%}"
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        if prob > 0.0
    ]
    return _pad_answers(answers)


inputs = [
    gradio.components.Image(label="Image"),
    gradio.components.Textbox(
        label="Question",
        placeholder="Enter your question here.",
    ),
]

# One textbox per answer slot: "Answer 1" ... "Answer 5".
outputs = [
    gradio.components.Textbox(label=f"Answer {rank}")
    for rank in range(1, NUM_ANSWERS + 1)
]

title = "Visual Question Answering (vilt-b32-finetuned-vqa)"

gradio.Interface(
    fn=predict_answer,
    inputs=inputs,
    outputs=outputs,
    title=title,
    allow_flagging="never",
    css="footer{display:none !important}",
).launch()