|
import asyncio

import gradio as gr
import torch
from transformers import AutoProcessor, CLIPModel
|
|
|
def _load_clip(checkpoint):
    """Load a CLIP checkpoint and its processor with ``from_pretrained``.

    Returns ``(model, processor)``; the model has already been put in eval
    mode. The model is loaded before the processor.
    """
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_model.eval()
    clip_processor = AutoProcessor.from_pretrained(checkpoint)
    return clip_model, clip_processor


# Two CLIP checkpoints; predict() scores every label with both of them.
# (Per the checkpoint names: OpenAI ViT-B/32 and a LAION-2B ViT-H/14.)
clip_path = "openai/clip-vit-base-patch32"
model, processor = _load_clip(clip_path)

clip_path2 = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
model2, processor2 = _load_clip(clip_path2)
|
|
|
def _parse_labels(labels_text):
    """Split a comma-separated label string into trimmed, non-empty labels.

    ``None``, ``""`` or a string containing only commas/whitespace yields ``[]``.
    """
    if not labels_text:
        return []
    return [label.strip() for label in labels_text.split(",") if label.strip()]


def _logits_for(clip_model, clip_processor, image, labels):
    """Run one CLIP model on ``image`` against ``labels``.

    Returns row 0 of ``logits_per_image`` (the single input image) as a list of
    Python floats, one per label, in label order.
    """
    model_inputs = clip_processor(
        text=labels, images=image, return_tensors="pt", padding=True
    )
    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        outputs = clip_model(**model_inputs)
    return outputs.logits_per_image[0].tolist()


def _classify(image, labels):
    """Score ``labels`` with both CLIP models and format one line per label."""
    scores1 = _logits_for(model, processor, image, labels)
    scores2 = _logits_for(model2, processor2, image, labels)
    return "".join(
        f"{label}: {s1:.2f}, {s2:.2f}\n"
        for label, s1, s2 in zip(labels, scores1, scores2)
    )


async def predict(init_image, labels_level1):
    """Classify ``init_image`` against comma-separated labels with two CLIP models.

    Args:
        init_image: the image to classify, or ``None`` (no image -> empty result).
        labels_level1: comma-separated candidate labels, e.g.
            ``"cartoon,painting,screenshot"``. Whitespace around each label is
            trimmed and empty entries are ignored.

    Returns:
        The same text twice (one copy per output box). Each line is
        ``"<label>: <logit model 1>, <logit model 2>"`` using the raw
        ``logits_per_image`` values (no softmax is applied). Returns
        ``("", "")`` when there is no image or no non-blank label.
    """
    if init_image is None:
        return "", ""
    labels = _parse_labels(labels_level1)
    if not labels:
        # Nothing to score; avoid calling the processors with no usable text.
        return "", ""
    # The forward passes are long blocking calls; run them in the default
    # executor so this coroutine does not stall the event loop meanwhile.
    loop = asyncio.get_running_loop()
    ret_str = await loop.run_in_executor(None, _classify, init_image, labels)
    return ret_str, ret_str
|
|
|
|
|
# Page CSS passed to gr.Blocks below: horizontally centres the element with
# id "container" and caps it at 80rem wide; centres the element with id
# "intro" and its text.
css = """
#container{
    margin: 0 auto;
    max-width: 80rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""
|
# Build the Gradio UI and wire both the button and image changes to predict().
with gr.Blocks(css=css) as demo:
    # NOTE(review): this State is never used as an input or output of any
    # event below — looks like dead code; confirm before removing.
    init_image_state = gr.State()
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """# Clip Demo
            """,
            elem_id="intro",
        )
        # Controls row: candidate labels, text result, and trigger button.
        with gr.Row():
            # Comma-separated candidate labels. UI label means "set major
            # categories"; NOTE(review): "类别" appears twice — likely a typo.
            txt_input = gr.Textbox(
                value="cartoon,painting,screenshot",
                interactive=True, label="设定大类别类别", scale=5)
            txt = gr.Textbox(value="", label="Output:", scale=5)
            # Button text means "click to start classification".
            generate_bt = gr.Button("点击开始分类", scale=1)
        with gr.Row():
            with gr.Column():
                # type="pil": per Gradio's Image docs, the handler receives a
                # PIL image.
                image_input = gr.Image(
                    sources=["upload", "clipboard"],
                    label="User Image",
                    type="pil",
                )
            with gr.Row():
                # Second output box; label means "level-1 classification".
                prob_label = gr.Textbox(value="", label="一级分类")

        # Order must match predict(init_image, labels_level1).
        inputs = [image_input, txt_input]
        # predict returns the same text twice, one copy for each output box.
        generate_bt.click(fn=predict, inputs=inputs, outputs=[txt, prob_label], show_progress=True)
        # Re-run whenever the image changes; queue=False keeps this event out
        # of the queue enabled by demo.queue() below.
        image_input.change(
            fn=predict,
            inputs=inputs,
            outputs=[txt, prob_label],
            show_progress=True,
            queue=False,
        )

# NOTE(review): the `with` nesting above was reconstructed from a source whose
# indentation was lost — confirm the intended layout.
demo.queue().launch()
|
|