import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import CLIPProcessor, CLIPModel
import torch
import os
import random

device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP scores how well each caption matches the image.
model_id = "openai/clip-vit-base-patch16"  # You can choose a different CLIP model from Hugging Face
clipprocessor = CLIPProcessor.from_pretrained(model_id)
clipmodel = CLIPModel.from_pretrained(model_id).to(device)

# BLIP generates the machine's caption for the image.
model_id = "Salesforce/blip-image-captioning-base"
blipprocessor = BlipProcessor.from_pretrained(model_id)
blipmodel = BlipForConditionalGeneration.from_pretrained(model_id).to(device)

im_dir = os.path.join(os.getcwd(), "images")


def sample_image(im_dir=im_dir):
    """Pick a random image from the image directory and return updated components."""
    all_ims = os.listdir(im_dir)
    new_im = random.choice(all_ims)
    return (
        gr.Image(label="Target Image", interactive=False, type="pil",
                 value=os.path.join(im_dir, new_im), height=500),
        gr.Textbox(label="Image fname", value=new_im, interactive=False, visible=False),
    )


def evaluate_caption(image, caption):
    """Generate a BLIP caption, then let CLIP judge it against the human caption."""
    # Machine caption from BLIP.
    blip_input = blipprocessor(image, return_tensors="pt").to(device)
    out = blipmodel.generate(**blip_input, max_new_tokens=50)
    blip_caption = blipprocessor.decode(out[0], skip_special_tokens=True)

    # CLIP image-text similarity for the human caption and the BLIP caption.
    inputs = clipprocessor(text=[caption, blip_caption], images=image,
                           return_tensors="pt", padding=True).to(device)
    similarity_score = clipmodel(**inputs).logits_per_image

    # Softmax over the two captions gives a probability-like score for each.
    score = similarity_score.softmax(dim=1).detach().cpu().numpy()
    print(score)
    if score[0][0] > score[0][1]:
        winner = "The first caption is the human"
    else:
        winner = "The second caption is the human"
    return blip_caption, winner


# Flagged rounds are saved to a Hugging Face dataset. The token is read from the
# environment rather than hard-coded in the source.
callback = gr.HuggingFaceDatasetSaver(os.environ.get("HF_TOKEN"), "gradioTest")

with gr.Blocks() as demo:
    im_path_str = "n01677366_12918.JPEG"
    im_path = gr.Textbox(label="Image fname", value=im_path_str, interactive=False, visible=False)

    with gr.Column():
        im = gr.Image(label="Target Image", interactive=False, type="pil",
                      value=os.path.join(im_dir, im_path_str), height=500)
        caps = gr.Textbox(label="Player 1 Caption")
        submit_btn = gr.Button("Submit!!")

    with gr.Column():
        out1 = gr.Textbox(label="Player 2 (Machine) Caption", interactive=False)
        out2 = gr.Textbox(label="Winner", interactive=False)
        reload_btn = gr.Button("Next Image")

    callback.setup([caps, out1, out2, im_path], "flagged_data_points")

    # On submit: run the caption contest, then log the round once the call succeeds.
    submit_btn.click(
        fn=evaluate_caption, inputs=[im, caps], outputs=[out1, out2], api_name="test"
    ).success(lambda *args: callback.flag(args), [caps, out1, out2, im_path], None, preprocess=False)

    # "Next Image" swaps in a new random image and its filename.
    reload_btn.click(fn=sample_image, inputs=None, outputs=[im, im_path])

    # Optional manual flag button, kept from an earlier version:
    # with gr.Row():
    #     btn = gr.Button("Flag")
    #     btn.click(lambda *args: callback.flag(args), [im, caps, out1, out2], None, preprocess=False)

demo.launch()