Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import pipeline | |
from transformers import BlipProcessor, BlipForConditionalGeneration | |
from transformers import CLIPProcessor, CLIPModel | |
import torch | |
from PIL import Image | |
import requests | |
import os | |
import random | |
# Choose the accelerator once; only the CLIP model is placed on it below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP judges how well each caption matches the image.
# Any CLIP checkpoint from the Hugging Face Hub can be substituted here.
model_id = "openai/clip-vit-base-patch16"
clipprocessor = CLIPProcessor.from_pretrained(model_id)
clipmodel = CLIPModel.from_pretrained(model_id).to(device)

# BLIP writes the machine ("Player 2") caption. Note that, unlike CLIP, it is
# not moved to `device`, so it stays where from_pretrained loads it (CPU by
# default). `model_id` is deliberately reused and ends up holding this BLIP id.
model_id = "Salesforce/blip-image-captioning-base"
blipmodel = BlipForConditionalGeneration.from_pretrained(model_id)
blipprocessor = BlipProcessor.from_pretrained(model_id)

# Folder of candidate target images, resolved against the working directory.
im_dir = os.path.join(os.getcwd(), "images")
def sample_image(im_dir=im_dir):
    """Pick a random image from ``im_dir`` and build the UI updates for it.

    Args:
        im_dir: Directory holding the candidate target images. Defaults to the
            module-level ``images`` folder.

    Returns:
        A ``(gr.Image, gr.Textbox)`` pair: the image component showing the
        chosen file, and the hidden textbox carrying its file name (the value
        that is logged when a round is flagged).

    Raises:
        gr.Error: if ``im_dir`` contains no image files, so the user sees a
            readable message instead of a bare IndexError.
    """
    # Compared case-insensitively: the bundled images use upper-case ".JPEG".
    image_suffixes = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")
    # os.listdir also returns sub-directories and stray files (.DS_Store,
    # READMEs, ...); only real image files may be handed to gr.Image.
    candidates = [
        name
        for name in os.listdir(im_dir)
        if name.lower().endswith(image_suffixes)
        and os.path.isfile(os.path.join(im_dir, name))
    ]
    if not candidates:
        raise gr.Error(f"No image files found in {im_dir!r}")
    new_im = random.choice(candidates)
    return (
        gr.Image(
            label="Target Image",
            interactive=False,
            type="pil",
            value=os.path.join(im_dir, new_im),
            height=500,
        ),
        gr.Textbox(label="Image fname", value=new_im, interactive=False, visible=False),
    )
def evaluate_caption(image, caption):
    """Caption ``image`` with BLIP and let CLIP judge it against the player's caption.

    Args:
        image: The target image as delivered by the ``gr.Image(type="pil")``
            input (a PIL image).
        caption: The human player's caption text.

    Returns:
        ``(blip_caption, winner)``: the machine-generated caption and a message
        saying which of the two captions CLIP scored higher.

    Raises:
        gr.Error: if ``caption`` is empty. Raising here also stops the chained
            ``.success`` flagging step, so empty rounds are not logged.
    """
    if not caption or not caption.strip():
        raise gr.Error("Please write a caption before submitting.")

    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        # Inputs must live on the same device as the model consuming them;
        # asking each model for its device keeps this correct on CPU and CUDA.
        blip_input = blipprocessor(image, return_tensors="pt").to(blipmodel.device)
        out = blipmodel.generate(**blip_input, max_new_tokens=50)
        blip_caption = blipprocessor.decode(out[0], skip_special_tokens=True)

        # truncation=True: CLIP's text encoder has a fixed context length, and
        # an over-long human caption would otherwise crash the forward pass.
        inputs = clipprocessor(
            text=[caption, blip_caption],
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(clipmodel.device)
        # One row per image, one column per caption (human first, BLIP second).
        logits_per_image = clipmodel(**inputs).logits_per_image

    # Softmax over the caption axis; .cpu() is required before .numpy() when
    # the model ran on CUDA.
    score = logits_per_image.softmax(dim=1).cpu().numpy()
    print(score)
    if score[0][0] > score[0][1]:
        winner = "The first caption is the human"
    else:
        winner = "The second caption is the human"
    return blip_caption, winner
# Flagged rounds are pushed to the "gradioTest" dataset on the Hugging Face Hub.
# The write token is read from the environment (on Spaces: add a secret named
# HF_TOKEN) rather than hard-coded: anyone who can read the source can read a
# committed token. The token previously hard-coded here must be revoked.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError(
        "HF_TOKEN is not set; provide a Hugging Face write token via the "
        "environment so flagged rounds can be saved."
    )
callback = gr.HuggingFaceDatasetSaver(hf_token, "gradioTest")
with gr.Blocks() as demo:
    # File name of the first target image. It lives in a hidden textbox so it
    # is logged alongside each flagged round.
    im_path_str = "n01677366_12918.JPEG"
    im_path = gr.Textbox(label="Image fname", value=im_path_str, interactive=False, visible=False)

    with gr.Column():
        im = gr.Image(
            label="Target Image",
            interactive=False,
            type="pil",
            value=os.path.join(im_dir, im_path_str),
            height=500,
        )
        caps = gr.Textbox(label="Player 1 Caption")
        submit_btn = gr.Button("Submit!!")

    with gr.Column():
        out1 = gr.Textbox(label="Player 2 (Machine) Caption", interactive=False)
        out2 = gr.Textbox(label="Winner", interactive=False)
        reload_btn = gr.Button("Next Image")

    # The component list given to setup must match, in order, the values
    # passed to callback.flag below.
    callback.setup([caps, out1, out2, im_path], "flagged_data_points")

    # Score the round, then log it only if scoring succeeded (.success).
    submit_btn.click(
        fn=evaluate_caption,
        inputs=[im, caps],
        outputs=[out1, out2],
        api_name="test",
    ).success(
        lambda *args: callback.flag(args),
        [caps, out1, out2, im_path],
        None,
        preprocess=False,
    )

    # Load a new image; once that succeeds, clear the previous round's caption
    # and results so they are not shown (or flagged) against the new image.
    reload_btn.click(fn=sample_image, inputs=None, outputs=[im, im_path]).success(
        fn=lambda: ("", "", ""),
        inputs=None,
        outputs=[caps, out1, out2],
    )

demo.launch()