import os
from PIL import Image
import numpy as np
import torch
from transformers import AutoImageProcessor
import gradio as gr
from modeling_siglip import SiglipForImageClassification
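
# The model repo and read token come from environment variables (typically set
# as Space secrets), so a private checkpoint can be used without hardcoding
# credentials in this file.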
MODEL_NAME = os.environ["MODEL_NAME"]
PROCESSOR_NAME = MODEL_NAME
HF_TOKEN = os.environ["HF_READ_TOKEN"]
EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]
README_MD = """\
## SigLIP Tagger Test 3
An experimental model for tagging danbooru tags of images using SigLIP.
Model(s):
- [p1atdev/siglip-tagger-test-3](https://huggingface.co/p1atdev/siglip-tagger-test-3)
Example images by NovelAI and niji・journey.
"""
model = SiglipForImageClassification.from_pretrained(MODEL_NAME, token=HF_TOKEN)
processor = AutoImageProcessor.from_pretrained(PROCESSOR_NAME, token=HF_TOKEN)

def compose_text(results: dict[str, float], threshold: float = 0.3):
    # Join the tag names whose scores exceed the threshold, highest score first.
    return ", ".join(
        [
            key
            for key, value in sorted(results.items(), key=lambda x: x[1], reverse=True)
            if value > threshold
        ]
    )
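
# Example with hypothetical scores: compose_text({"1girl": 0.92, "outdoors": 0.15})
# returns "1girl" under the default 0.3 threshold.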

@torch.no_grad()
def predict_tags(image: Image.Image, threshold: float):
    inputs = processor(image, return_tensors="pt")
    logits = model(**inputs.to(model.device, model.dtype)).logits.detach().cpu()
    # Clip the logits into [0, 1] so they can be shown as per-tag confidences.
    logits = np.clip(logits, 0.0, 1.0)
    results = {}
    for prediction in logits:
        for i, prob in enumerate(prediction):
            if prob.item() > 0:
                results[model.config.id2label[i]] = prob.item()
    return compose_text(results, threshold), results
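
# A minimal smoke test outside the UI, assuming the sample image paths listed
# in EXAMPLES exist locally:
#   text, scores = predict_tags(Image.open("./images/sample.jpg"), threshold=0.3)
#   print(text)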
css = """\
.sticky {
position: sticky;
top: 16px;
}
.gradio-container {
overflow: clip;
}
"""

def demo():
    with gr.Blocks(css=css) as ui:
        gr.Markdown(README_MD)

        with gr.Row():
            with gr.Column():
                # The "sticky" class keeps the input controls pinned while
                # the output column scrolls (see the css string above).
                with gr.Row(elem_classes="sticky"):
                    with gr.Column():
                        input_img = gr.Image(
                            label="Input image", type="pil", height=480
                        )

                        with gr.Group():
                            tag_threshold_slider = gr.Slider(
                                label="Tags threshold",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.3,
                                step=0.01,
                            )

                        start_btn = gr.Button(value="Start", variant="primary")

                        gr.Examples(
                            examples=EXAMPLES,
                            inputs=[input_img],
                            cache_examples=False,
                        )

            with gr.Column():
                output_tags = gr.Text(label="Output text", interactive=False)
                output_label = gr.Label(label="Output tags")

        start_btn.click(
            fn=predict_tags,
            inputs=[input_img, tag_threshold_slider],
            outputs=[output_tags, output_label],
        )

    ui.launch(
        debug=True,
        # share=True
    )

if __name__ == "__main__":
    demo()