# Page header captured together with the listing (not part of the program):
# Spaces: Running
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
)

from modeling_siglip import SiglipForImageClassification
# Runtime configuration is read from environment variables.
# MODEL_NAME is mandatory: a missing variable raises KeyError at import time.
MODEL_NAME = os.environ["MODEL_NAME"]
# The image processor is loaded from the same repository as the model weights.
PROCESSOR_NAME = MODEL_NAME
# The read token is optional: a missing variable yields None, which the
# loaders accept, instead of aborting start-up with a KeyError.
HF_TOKEN = os.environ.get("HF_READ_TOKEN")
# Sample inputs offered in the UI; each row holds the value for one input component.
EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]
# Markdown rendered at the top of the demo page. Blank lines separate the
# Markdown blocks so the closing sentence is not absorbed into the list item.
README_MD = """\
## SigLIP Tagger Test 3

An experimental model for tagging danbooru tags of images using SigLIP.

Model(s):

- [p1atdev/siglip-tagger-test-3](https://huggingface.co/p1atdev/siglip-tagger-test-3)

Example images by NovelAI and niji・journey.
"""
# Load the classifier and its image preprocessor once at import time so that
# every prediction request reuses the same objects.
model = SiglipForImageClassification.from_pretrained(MODEL_NAME, token=HF_TOKEN)
processor = AutoImageProcessor.from_pretrained(PROCESSOR_NAME, token=HF_TOKEN)
def compose_text(results: dict[str, float], threshold: float = 0.3) -> str:
    """Join the tags scoring above *threshold* into a comma-separated string.

    Args:
        results: Mapping of tag name to score.
        threshold: Exclusive lower bound; a tag is kept only when its score
            is strictly greater than this value.

    Returns:
        The kept tag names ordered from highest to lowest score and joined
        with ``", "``; an empty string when no tag qualifies.
    """
    # Sort first, filter second: the sort is stable, so tags with equal
    # scores keep the order in which they appear in ``results``.
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    return ", ".join(tag for tag, score in ranked if score > threshold)
def predict_tags(image: Image.Image, threshold: float):
    """Run the tagger on one image and collect the positively scored tags.

    Args:
        image: Input picture; the UI supplies it as a PIL image.
        threshold: Exclusive minimum score for a tag to appear in the text
            output (forwarded to ``compose_text``).

    Returns:
        A pair ``(text, scores)``. ``text`` is the comma-separated tag string
        built by ``compose_text``; ``scores`` maps each tag label whose
        clamped score is greater than zero to that score.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # An empty image component hands over None; report that to the user
        # instead of failing inside the processor with an opaque error.
        raise gr.Error("Please provide an input image.")

    inputs = processor(image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**inputs.to(model.device, model.dtype)).logits

    # Clamp the raw outputs to [0, 1] with torch itself. Converting to float32
    # first keeps this working for reduced-precision model dtypes that NumPy
    # cannot represent.
    scores = torch.clamp(logits.detach().cpu().float(), 0.0, 1.0)

    id2label = model.config.id2label
    results = {}
    for row in scores.tolist():
        for index, score in enumerate(row):
            if score > 0:
                results[id2label[index]] = score

    return compose_text(results, threshold), results
# Extra stylesheet handed to gr.Blocks. ".sticky" pins the element it is
# applied to 16px below the top edge while scrolling.
# NOTE(review): "overflow: clip" on the container is presumably there so an
# ancestor scroll container does not defeat the sticky positioning; confirm.
css = """\
.sticky {
position: sticky;
top: 16px;
}
.gradio-container {
overflow: clip;
}
"""
def demo():
    """Build the Gradio interface for the tagger and launch it.

    The page shows the README text, an input side (image, threshold slider,
    start button, example images) and an output side (tag text and per-tag
    label scores). Pressing the start button calls ``predict_tags``.
    """
    # NOTE(review): the nesting below was reconstructed from a listing whose
    # indentation had been lost; confirm the layout against the running app.
    with gr.Blocks(css=css) as ui:
        gr.Markdown(README_MD)

        with gr.Row():
            # Input side; the "sticky" class refers to the rule in ``css``.
            with gr.Column():
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )

                        with gr.Group():
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )

                        start_btn = gr.Button(value="Start", variant="primary")

                        gr.Examples(
                            examples=EXAMPLES,
                            inputs=[input_img],
                            cache_examples=False,
                        )

            # Output side.
            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")

        # Event wiring is declared while the Blocks context is still open.
        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )

    # Sharing is left at the launch() default; pass share=True to request a
    # public link.
    ui.launch(
        debug=True,
    )


if __name__ == "__main__":
    demo()