"""Gradio demo that tags an image with danbooru tags using a SigLIP classifier."""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor

from modeling_siglip import SiglipForImageClassification

MODEL_NAME = "p1atdev/siglip-tagger-test-3"
PROCESSOR_NAME = "google/siglip-so400m-patch14-384"

# The model and processor are loaded once at import time and shared by every
# request handled by `predict_tags`.
model = SiglipForImageClassification.from_pretrained(MODEL_NAME)
# model = torch.compile(model)
processor = AutoProcessor.from_pretrained(PROCESSOR_NAME)


def compose_text(results: dict[str, float], threshold: float = 0.3) -> str:
    """Return the tags scoring above *threshold* as a comma-separated string.

    Tags are ordered from the highest score to the lowest. Tags whose score
    is less than or equal to *threshold* are dropped.
    """
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    return ", ".join(tag for tag, score in ranked if score > threshold)


@torch.no_grad()
def predict_tags(
    image: Image.Image, threshold: float
) -> tuple[str, dict[str, float]]:
    """Run the tagger on *image*.

    Returns a pair ``(text, results)``: ``results`` maps every tag with a
    score greater than zero (after clipping scores to ``[0, 1]``) to that
    score, and ``text`` is ``compose_text(results, threshold)``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        # The image component yields None when the user has not provided an
        # image; report that instead of failing inside the processor.
        raise gr.Error("Please provide an input image.")

    inputs = processor(images=image, return_tensors="pt")
    # Move the inputs to the model's device and cast them to its dtype.
    logits = model(**inputs.to(model.device, model.dtype)).logits.detach().cpu()

    # Convert to NumPy once (via float32, which NumPy always supports) and
    # clip the raw scores to the [0, 1] range.
    scores = np.clip(logits.float().numpy(), 0.0, 1.0)

    id2label = model.config.id2label
    results: dict[str, float] = {}
    for prediction in scores.tolist():
        for index, score in enumerate(prediction):
            if score > 0:
                results[id2label[index]] = score

    return compose_text(results, threshold), results


css = """\
.sticky {
  position: sticky;
  top: 16px;
}

.gradio-container {
  overflow: clip;
}
"""


def demo() -> None:
    """Build the Gradio interface and launch it."""
    with gr.Blocks(css=css) as ui:
        gr.Markdown(
            """\
## SigLIP Tagger Test 3
An experimental model for tagging danbooru tags of images using SigLIP.

Models:
- (soon)

Example images by NovelAI and niji・journey.
"""
        )

        with gr.Row():
            # Left column: input controls, kept in view by the "sticky" class.
            with gr.Column():
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )

                        with gr.Group():
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )

                        start_btn = gr.Button(value="Start", variant="primary")

                        gr.Examples(
                            examples=[["./sample.jpg"], ["./sample2.webp"]],
                            inputs=[input_img],
                            cache_examples=False,
                        )

            # Right column: the joined tag text and the per-tag scores.
            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")

        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )

    ui.launch(
        debug=True,
        # share=True
    )


if __name__ == "__main__":
    demo()