SigLIP_Tagger / app.py
p1atdev's picture
chore: not to compile
9126ead
raw
history blame
3.11 kB
import os
from PIL import Image
import numpy as np
import torch
from transformers import (
AutoImageProcessor,
)
import gradio as gr
from modeling_siglip import SiglipForImageClassification
# Checkpoint to load; required, read from the environment at import time.
MODEL_NAME = os.environ["MODEL_NAME"]
# The image processor config is loaded from the same repository as the model.
PROCESSOR_NAME = MODEL_NAME
# Read token is optional: indexing os.environ would crash startup with
# KeyError whenever the secret is not configured, even for public
# checkpoints. With None, `from_pretrained` falls back to the default
# Hugging Face credential lookup.
HF_TOKEN = os.environ.get("HF_READ_TOKEN")
# Each example is a one-element list: one value per gr.Examples input.
EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]
README_MD = """\
## SigLIP Tagger Test 3
An experimental model for tagging danbooru tags of images using SigLIP.
Model(s):
- [p1atdev/siglip-tagger-test-3](https://huggingface.co/p1atdev/siglip-tagger-test-3)
Example images by NovelAI and niji・journey.
"""
# Loaded once at import time and shared by every request.
model = SiglipForImageClassification.from_pretrained(MODEL_NAME, token=HF_TOKEN)
processor = AutoImageProcessor.from_pretrained(PROCESSOR_NAME, token=HF_TOKEN)
def compose_text(results: dict[str, float], threshold: float = 0.3):
    """Build the comma-separated tag string shown in the text output.

    Args:
        results: Mapping of tag name to score.
        threshold: A tag is kept only if its score is strictly greater
            than this value.

    Returns:
        Kept tags ordered from highest to lowest score, joined with
        ``", "``. Tags with equal scores keep their order in ``results``.
        Returns ``""`` when no tag qualifies.
    """
    # Rank every tag first, then filter, so ordering matches a plain
    # stable descending sort of the whole mapping.
    ranked_tags = sorted(results, key=results.get, reverse=True)
    return ", ".join(tag for tag in ranked_tags if results[tag] > threshold)
@torch.no_grad()
def predict_tags(image: Image.Image, threshold: float):
    """Run the tagger on one image and return its tags for the UI.

    Args:
        image: Image from the Gradio input component, or ``None`` when the
            user presses "Start" before providing an image.
        threshold: Minimum score (exclusive) a tag needs to appear in the
            comma-separated text output.

    Returns:
        A ``(text, results)`` tuple. ``text`` is built by ``compose_text``
        from the tags scoring above ``threshold``. ``results`` maps every
        tag with a positive clipped score to that score, in label-index
        order. Both are empty when ``image`` is ``None``.
    """
    # Without this guard the processor raises on None and the UI shows an
    # error instead of simply producing empty outputs.
    if image is None:
        return "", {}

    inputs = processor(image, return_tensors="pt")
    logits = model(**inputs.to(model.device, model.dtype)).logits.detach().cpu()
    # Raw logits are clipped into [0, 1] and used directly as scores (no
    # sigmoid). NOTE(review): presumably this matches how the checkpoint was
    # trained; confirm.
    scores = logits.clamp(0.0, 1.0)

    id2label = model.config.id2label
    results = {}
    for row in scores:
        # Convert each row with one tolist() call rather than creating a
        # tensor and calling .item() (twice) for every label.
        for label_idx, score in enumerate(row.tolist()):
            if score > 0:
                results[id2label[label_idx]] = score
    return compose_text(results, threshold), results
# Extra stylesheet passed to gr.Blocks. The ".sticky" class is put on the
# input-side row in demo() so it stays 16px from the top while scrolling.
# The container uses `overflow: clip` rather than `hidden`: `clip` does not
# create a scroll container, so the sticky positioning keeps working.
css = (
    ".sticky {\n"
    "position: sticky;\n"
    "top: 16px;\n"
    "}\n"
    ".gradio-container {\n"
    "overflow: clip;\n"
    "}\n"
)
def _build_ui():
    """Lay out the tagger page and wire the Start button to predict_tags."""
    with gr.Blocks(css=css) as page:
        gr.Markdown(README_MD)

        with gr.Row():
            # Input side, wrapped in the "sticky" row so it stays in view
            # while the outputs on the right are scrolled.
            with gr.Column():
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        image_box = gr.Image(
                            label="Input image", type="pil", height=480
                        )
                        with gr.Group():
                            threshold_box = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )
                        run_button = gr.Button(value="Start", variant="primary")
                        gr.Examples(
                            examples=EXAMPLES,
                            inputs=[image_box],
                            cache_examples=False,
                        )

            # Output side: tag text plus the per-tag score view.
            with gr.Column():
                text_box = gr.Text(label="Output text", interactive=False)
                label_box = gr.Label(label="Output tags")

        # predict_tags returns (text, scores), one value per output below.
        run_button.click(
            fn=predict_tags,
            inputs=[image_box, threshold_box],
            outputs=[text_box, label_box],
        )
    return page


def demo():
    """Build the Gradio UI and serve it in debug mode (no public share link)."""
    _build_ui().launch(debug=True)


if __name__ == "__main__":
    demo()