"""Gradio demo comparing CLIP image-text scores from two checkpoints.

The user uploads (or pastes) an image and enters comma-separated labels.
For every label the app shows the image-text logit produced by
``openai/clip-vit-base-patch32`` followed by the one produced by
``laion/CLIP-ViT-H-14-laion2B-s32B-b79K``.
"""

import asyncio

import gradio as gr
import torch
from transformers import AutoProcessor, CLIPModel

clip_path = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(clip_path).eval()
processor = AutoProcessor.from_pretrained(clip_path)

clip_path2 = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
model2 = CLIPModel.from_pretrained(clip_path2).eval()
processor2 = AutoProcessor.from_pretrained(clip_path2)


def _parse_labels(labels_text):
    """Split a comma-separated string into stripped, non-empty labels."""
    if not labels_text:
        return []
    return [label.strip() for label in labels_text.split(",") if label.strip()]


def _image_text_logits(clip_model, clip_processor, image, labels):
    """Return one image-text logit (Python float) per label for a single image."""
    inputs = clip_processor(
        text=labels,
        images=image,
        return_tensors="pt",
        padding=True,
        # Long labels would otherwise exceed the text encoder's context length.
        truncation=True,
    )
    # Inference only: skip building an autograd graph. Grad mode is
    # thread-local, so it must be set here, in the worker thread.
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # Exactly one image is passed in, so row 0 holds its score per label.
    # These are raw logits; no softmax is applied.
    return outputs.logits_per_image[0].tolist()


def _predict_sync(init_image, labels_level1):
    """Blocking implementation of ``predict``. See its docstring."""
    labels = _parse_labels(labels_level1)
    if init_image is None or not labels:
        return "", ""
    scores = _image_text_logits(model, processor, init_image, labels)
    scores2 = _image_text_logits(model2, processor2, init_image, labels)
    ret_str = "".join(
        f"{label}: {s1:.2f}, {s2:.2f}\n"
        for label, s1, s2 in zip(labels, scores, scores2)
    )
    return ret_str, ret_str


async def predict(init_image, labels_level1):
    """Score each comma-separated label against the image with both CLIP models.

    Args:
        init_image: PIL image from the ``gr.Image`` component, or ``None`` when
            the image was cleared.
        labels_level1: Comma-separated label string. Whitespace around labels
            is stripped and empty entries are dropped.

    Returns:
        A pair of identical strings, one line per label, formatted as
        ``"<label>: <ViT-B/32 logit>, <ViT-H/14 logit>"``. Returns ``("", "")``
        when there is no image or no usable label.
    """
    # CLIP inference is CPU/GPU-bound and blocking. Run it in the default
    # executor so the async handler does not stall Gradio's event loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _predict_sync, init_image, labels_level1)


css = """
#container{
  margin: 0 auto;
  max-width: 80rem;
}
#intro{
  max-width: 100%;
  text-align: center;
  margin: 0 auto;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("""# Clip Demo """, elem_id="intro")
        with gr.Row():
            # Label text: "Set the top-level categories".
            txt_input = gr.Textbox(
                value="cartoon,painting,screenshot",
                interactive=True,
                label="设定大类别类别",
                scale=5,
            )
            txt = gr.Textbox(value="", label="Output:", scale=5)
            # Button text: "Click to start classification".
            generate_bt = gr.Button("点击开始分类", scale=1)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    sources=["upload", "clipboard"],
                    label="User Image",
                    type="pil",
                )
        with gr.Row():
            # Label text: "Level-1 classification".
            prob_label = gr.Textbox(value="", label="一级分类")

    inputs = [image_input, txt_input]
    generate_bt.click(
        fn=predict,
        inputs=inputs,
        outputs=[txt, prob_label],
        show_progress=True,
    )
    # Re-score automatically whenever the image changes. queue=False
    # bypasses the request queue for this event.
    image_input.change(
        fn=predict,
        inputs=inputs,
        outputs=[txt, prob_label],
        show_progress=True,
        queue=False,
    )

if __name__ == "__main__":
    demo.queue().launch()