"""Gradio front-end for the WD Tagger heatmap generator.

Builds a two-column UI (input image + threshold on the left, heatmaps / grid /
tags on the right), wires the Submit button to :func:`predict`, and launches
the app when run as a script.
"""

from os import getenv
from pathlib import Path

import gradio as gr
from PIL import Image
from rich.traceback import install as traceback_install

from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image
from tagger.model import load_model_and_transform, process_heatmap

TITLE = "WD Tagger Heatmap"
DESCRIPTION = """WD Tagger v3 Heatmap Generator."""

# get HF token
HF_TOKEN = getenv("HF_TOKEN", None)

# model repo and cache
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

# get the repo root (or the current working directory if running in ipython)
WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()

# allowed extensions
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# install rich's traceback handler; the returned value is not needed
_ = traceback_install(show_locals=True, locals_max_length=0)

# get the example images (paths relative to WORK_DIR, sorted by name).
# A missing "examples" directory yields an empty list instead of crashing
# the whole module at import time.
_examples_dir = WORK_DIR.joinpath("examples")
if _examples_dir.is_dir():
    example_images = sorted(
        str(x.relative_to(WORK_DIR))
        for x in _examples_dir.iterdir()
        if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
    )
else:
    example_images = []


def predict(
    image: Image.Image,
    threshold: float = 0.5,
):
    """Run the tagger on ``image`` and return heatmaps plus label outputs.

    Args:
        image: Input image as delivered by the ``gr.Image`` input component
            (``type="pil"``). Gradio passes ``None`` when no image is loaded.
        threshold: Tag threshold forwarded to ``process_heatmap``.

    Returns:
        A 7-tuple matching the ``outputs`` list of the Submit handler:
        ``(heatmap_images, heatmap_grid, caption, booru, rating, character,
        general)``, where ``heatmap_images`` is a list of ``(image, label)``
        pairs built from the heatmaps and the remaining label fields are read
        from the ``ImageLabels`` object returned by ``process_heatmap``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Submit can be clicked before an image is loaded; fail with a readable
    # message instead of an opaque error from the preprocessing code.
    if image is None:
        raise gr.Error("No input image provided. Please upload or paste an image first.")

    # load the model and its matching input transform for the configured repo
    model, transform = load_model_and_transform(MODEL_REPO)

    # load labels
    labels: LabelData = load_labels_hf(MODEL_REPO)

    # preprocess image; unsqueeze(0) adds a leading dimension
    # (presumably the batch axis -- confirm against tagger.model)
    image = preprocess_image(image, (448, 448))
    image = transform(image).unsqueeze(0)

    # get the model output
    heatmaps: list[Heatmap]
    image_labels: ImageLabels
    heatmaps, heatmap_grid, image_labels = process_heatmap(model, image, labels, threshold)

    # the gallery component takes (image, caption) pairs
    heatmap_images = [(x.image, x.label) for x in heatmaps]

    return (
        heatmap_images,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )


css = """
#use_mcut, #char_mcut {
    padding-top: var(--scale-3);
}
#threshold.dimmed {
    filter: brightness(75%);
}
"""

with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # left column: input image, threshold slider, action buttons
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.35,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
            with gr.Row():
                # components are attached further down, once they all exist
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")

        # right column: result tabs
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")

    # one example row per image: [image path, threshold]
    example_rows = [[imgpath, 0.35] for imgpath in example_images]
    # Only build the examples widget when there is something to show; an empty
    # example list is presumably rejected by gr.Examples for multiple inputs
    # (verify against the installed gradio version).
    examples = None
    if example_rows:
        with gr.Row():
            examples = gr.Examples(
                examples=example_rows,
                inputs=[img_input, threshold],
            )

    # tell clear button which components to clear
    clear.add([img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general])

    submit.click(
        predict,
        inputs=[img_input, threshold],
        outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
        api_name="predict",
    )

if __name__ == "__main__":
    demo.queue(max_size=10)
    # When SPACE_ID is set (presumably by the hosting environment, e.g.
    # Hugging Face Spaces) launch with defaults; otherwise run a local
    # debug server on a fixed port, listening on all interfaces.
    if getenv("SPACE_ID", None) is not None:
        demo.launch()
    else:
        demo.launch(
            server_name="0.0.0.0",
            server_port=7871,
            debug=True,
        )