import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForQuestionAnswering, ViltForQuestionAnswering
import torch
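
# download sample images used by the demo's examples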
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')
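
# load GIT and BLIP checkpoints fine-tuned for VQA (base and large variants of each)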
git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
git_processor_large = AutoProcessor.from_pretrained("microsoft/git-large-vqav2")
git_model_large = AutoModelForCausalLM.from_pretrained("microsoft/git-large-vqav2")
blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model_base = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
# vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
# vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
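
# move all models to the GPU when one is available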
device = "cuda" if torch.cuda.is_available() else "cpu"
git_model_base.to(device)
blip_model_base.to(device)
git_model_large.to(device)
blip_model_large.to(device)
# vilt_model.to(device)

def generate_answer_git(processor, model, image, question):
    # prepare image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # prepare question: GIT expects a CLS token followed by the question tokens
    input_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = [processor.tokenizer.cls_token_id] + input_ids
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=128)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # GIT echoes the question before the answer, so strip it from the decoded text
    return generated_answer[0].replace(question, '').replace(question.lower(), '').strip()

def generate_answer_blip(processor, model, image, question):
    # prepare image + question
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)

    generated_ids = model.generate(**inputs, max_length=128)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return generated_answer[0].strip()

def generate_answer_vilt(processor, model, image, question):
    # prepare image + question
    encoding = processor(images=image, text=question, max_length=128, return_tensors="pt").to(device)

    # ViLT treats VQA as classification over a fixed set of answers
    with torch.no_grad():
        outputs = model(**encoding)

    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]

def generate_answers(image, question):
    answer_git_base = generate_answer_git(git_processor_base, git_model_base, image, question)
    answer_git_large = generate_answer_git(git_processor_large, git_model_large, image, question)
    answer_blip_base = generate_answer_blip(blip_processor_base, blip_model_base, image, question)
    answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)
    # answer_vilt = generate_answer_vilt(vilt_processor, vilt_model, image, question)
    return answer_git_base, answer_git_large, answer_blip_base, answer_blip_large  # , answer_vilt
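
# Example of calling the comparison function directly (a minimal sketch for local
# testing, assuming the sample images above were downloaded successfully):
#   from PIL import Image
#   print(generate_answers(Image.open("cats.jpg"), "How many cats are there?"))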

examples = [
    ["cats.jpg", "How many cats are there?"],
    ["stop_sign.png", "What's behind the stop sign?"],
    ["astronaut.jpg", "What's the astronaut riding on?"],
]
# one text output per model
outputs = [
    gr.Textbox(label="Answer generated by GIT-base"),
    gr.Textbox(label="Answer generated by GIT-large"),
    gr.Textbox(label="Answer generated by BLIP-base"),
    gr.Textbox(label="Answer generated by BLIP-large"),
    # gr.Textbox(label="Answer generated by ViLT"),
]
title = "Interactive demo: comparing visual question answering (VQA) models"
description = "Gradio demo to compare GIT and BLIP, two state-of-the-art vision+language models, on visual question answering. To use it, simply upload an image, type a question and click 'Submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"
interface = gr.Interface(
    fn=generate_answers,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
    outputs=outputs,
    examples=examples,
    title=title,
    description=description,
    article=article,
)
interface.queue().launch(debug=True)