clipdemo / app.py
keyishen's picture
Update app.py
41871e7 verified
import asyncio

import gradio as gr
import torch
from transformers import AutoProcessor, CLIPModel
# Two CLIP checkpoints are compared side by side. Both are loaded once, at
# import time, and predict() reads the resulting module-level names.
# `.eval()` puts each model in evaluation mode before it is stored.
clip_path = "openai/clip-vit-base-patch32"
clip_path2 = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

model, processor = (
    CLIPModel.from_pretrained(clip_path).eval(),
    AutoProcessor.from_pretrained(clip_path),
)
model2, processor2 = (
    CLIPModel.from_pretrained(clip_path2).eval(),
    AutoProcessor.from_pretrained(clip_path2),
)
def _image_logits(clip_model, clip_processor, image, labels):
    """Return the image-text logits of ``image`` against each of ``labels``.

    ``logits_per_image`` has one row per image and one column per text. Only a
    single image is passed in, so row 0 holds one float per label, in order.
    """
    model_inputs = clip_processor(
        text=labels, images=image, return_tensors="pt", padding=True
    )
    # Inference only: no_grad skips building an autograd graph on every call.
    with torch.no_grad():
        model_outputs = clip_model(**model_inputs)
    return model_outputs.logits_per_image[0].tolist()


def _score_labels(image, labels):
    """Format one ``"label: <model logit>, <model2 logit>\\n"`` line per label."""
    scores = _image_logits(model, processor, image, labels)
    scores2 = _image_logits(model2, processor2, image, labels)
    return "".join(
        f"{label}: {s1:.2f}, {s2:.2f}\n"
        for label, s1, s2 in zip(labels, scores, scores2)
    )


async def predict(init_image, labels_level1):
    """Score an image against comma-separated labels with both CLIP models.

    Args:
        init_image: Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the input has been cleared.
        labels_level1: Comma-separated candidate labels, e.g.
            ``"cartoon,painting,screenshot"``. Whitespace around each label is
            stripped and empty entries (e.g. from a trailing comma) are ignored.

    Returns:
        A pair of identical strings, one for each output component. Each string
        has one line per label of the form ``"label: <logit>, <logit2>"``, where
        the two numbers are the raw ``logits_per_image`` of ``model`` and
        ``model2`` (scaled image-text similarity, not probabilities).
        Returns ``("", "")`` when there is no image or no usable label.
    """
    if init_image is None:
        return "", ""
    labels = [label.strip() for label in (labels_level1 or "").split(",")]
    labels = [label for label in labels if label]
    if not labels:
        # Nothing to score; the processor cannot handle an empty text batch.
        return "", ""
    # Model inference is blocking work. Running it directly inside this
    # coroutine would stall the event loop for every session until it finished,
    # so hand it to the loop's default thread-pool executor instead.
    loop = asyncio.get_running_loop()
    ret_str = await loop.run_in_executor(None, _score_labels, init_image, labels)
    return ret_str, ret_str
# Custom stylesheet passed to gr.Blocks(css=...). "#container" targets the
# main gr.Column (elem_id="container"), centring it and capping it at 80rem
# wide. "#intro" targets the title Markdown (elem_id="intro"), giving it full
# width with centred text.
css = """
#container{
margin: 0 auto;
max-width: 80rem;
}
#intro{
max-width: 100%;
text-align: center;
margin: 0 auto;
}
"""
# Gradio UI: a label textbox ("set top-level categories"), a text output, and a
# "click to start classifying" button, plus an image input and a second text
# output ("first-level category"). Both outputs receive the same text from
# predict().
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("# Clip Demo\n", elem_id="intro")
        with gr.Row():
            txt_input = gr.Textbox(
                value="cartoon,painting,screenshot",
                interactive=True,
                label="设定大类别",
                scale=5,
            )
            txt = gr.Textbox(value="", label="Output:", scale=5)
            generate_bt = gr.Button("点击开始分类", scale=1)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    sources=["upload", "clipboard"],
                    label="User Image",
                    type="pil",
                )
            with gr.Row():
                prob_label = gr.Textbox(value="", label="一级分类")

    inputs = [image_input, txt_input]
    outputs = [txt, prob_label]
    # Classify on an explicit button click and also whenever the image input
    # changes. The change event bypasses the queue (queue=False).
    generate_bt.click(
        fn=predict, inputs=inputs, outputs=outputs, show_progress=True
    )
    image_input.change(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        show_progress=True,
        queue=False,
    )

# Launch only when run as a script, so importing this module (e.g. from tests
# or tooling) does not start a server.
if __name__ == "__main__":
    demo.queue().launch()