SmilingWolf's picture
Add support for model selection
f6dbb10
raw
history blame
6.01 kB
#!/usr/bin/env python
from __future__ import annotations
import argparse
import functools
import html
import os
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import piexif
import piexif.helper
import PIL.Image
from Utils import dbimutils
# Text shown at the top of the Gradio interface.
TITLE = "WaifuDiffusion v1.4 Tags"
DESCRIPTION = """
Demo for [SmilingWolf/wd-v1-4-vit-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) and [SmilingWolf/wd-v1-4-convnext-tagger](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger) with "ready to copy" prompt and a prompt analyzer.
Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
PNG Info code forked from [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
"""
# Hub access token passed to every download. Optional: indexing os.environ
# directly made the whole module fail to import with a KeyError when the
# variable was unset. With None, huggingface_hub presumably falls back to any
# locally cached credentials (confirm for the installed version).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub repositories holding the two tagger models.
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger"
# File names inside each repository: the ONNX graph and the tag list CSV.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--score-slider-step", type=float, default=0.05)
parser.add_argument("--score-threshold", type=float, default=0.35)
parser.add_argument("--share", action="store_true")
return parser.parse_args()
def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
    """Download ``model_filename`` from the Hub repo ``model_repo`` and open it.

    The download is authenticated with the module-level ``HF_TOKEN``.

    Returns:
        An ONNX Runtime inference session built from the downloaded file.
    """
    return rt.InferenceSession(
        huggingface_hub.hf_hub_download(
            model_repo, model_filename, use_auth_token=HF_TOKEN
        )
    )
def load_labels() -> list[str]:
    """Return the ``name`` column of the ViT repo's label CSV as a list.

    The caller zips this list against the model's output scores, so its order
    must match the output order. The same list is used for every model; this
    assumes the ConvNext tagger shares the ViT tag vocabulary.
    """
    name_column = pd.read_csv(
        huggingface_hub.hf_hub_download(
            VIT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
        )
    )["name"]
    return name_column.tolist()
def plaintext_to_html(text):
    """HTML-escape ``text`` and wrap it in a single paragraph.

    Newlines become ``<br>`` line breaks, so multi-line values keep their
    layout when rendered.
    """
    escaped_lines = [html.escape(line) for line in text.split("\n")]
    body = "<br>\n".join(escaped_lines)
    return f"<p>{body}</p>"
def _prepare_input(image: PIL.Image.Image, size: int) -> np.ndarray:
    """Turn a PIL image into the float32, batch-of-one BGR array fed to the model.

    ``size`` is passed to both ``dbimutils.make_square`` and
    ``dbimutils.smart_resize``; presumably the result is ``size`` x ``size``
    pixels (confirm against dbimutils).
    """
    # Flatten any transparency onto a white background.
    rgba = image.convert("RGBA")
    canvas = PIL.Image.new("RGBA", rgba.size, "WHITE")
    canvas.paste(rgba, mask=rgba)
    pixels = np.asarray(canvas.convert("RGB"))
    # PIL RGB to OpenCV BGR
    pixels = pixels[:, :, ::-1]
    pixels = dbimutils.make_square(pixels, size)
    pixels = dbimutils.smart_resize(pixels, size)
    pixels = pixels.astype(np.float32)
    # Prepend the batch axis.
    return np.expand_dims(pixels, 0)


def _format_tags(tag_scores: dict[str, float]) -> tuple[str, str]:
    """Return ``(prompt, raw)`` tag strings ordered by descending score.

    ``raw`` is the tag names joined with ", ". ``prompt`` is the same string
    with underscores turned into spaces and parentheses backslash-escaped
    (presumably so prompt UIs that give ``(`` ``)`` special meaning read them
    literally). Ties keep their original order because ``sorted`` is stable.
    """
    ordered = sorted(tag_scores, key=tag_scores.get, reverse=True)
    raw = ", ".join(ordered)
    prompt = raw.replace("_", " ").replace("(", "\\(").replace(")", "\\)")
    return prompt, raw


# Technical image.info keys that are not worth showing to the user.
_HIDDEN_INFO_FIELDS = (
    "jfif",
    "jfif_version",
    "jfif_unit",
    "jfif_density",
    "dpi",
    "exif",
    "loop",
    "background",
    "timestamp",
    "duration",
)


def _image_metadata_html(image: PIL.Image.Image) -> str:
    """Render the image's embedded metadata (``image.info``) as HTML.

    An EXIF UserComment, if present and non-empty, is decoded and shown as
    "exif comment". Keys listed in ``_HIDDEN_INFO_FIELDS`` are dropped. If
    nothing is left, a "Nothing found" message is returned instead.
    """
    # Work on a copy: the original popped keys from the caller's image.info.
    items = dict(image.info)
    if "exif" in items:
        exif = piexif.load(items["exif"])
        comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
        try:
            comment = piexif.helper.UserComment.load(comment)
        except ValueError:
            # Not a well-formed UserComment (e.g. missing the charset prefix):
            # fall back to a lenient UTF-8 decode of the raw bytes.
            comment = comment.decode("utf8", errors="ignore")
        if comment:
            items["exif comment"] = comment
    for field in _HIDDEN_INFO_FIELDS:
        items.pop(field, None)

    # The original checked the rendered HTML, which always contained the
    # header and was therefore never empty; check the data instead.
    if not items:
        return "<div><p>Nothing found in the image.</p></div>"

    parts = ["\n<p><h4>PNG Info</h4></p>\n"]
    for key, value in items.items():
        parts.append(
            "<div>\n"
            f"<p><b>{plaintext_to_html(str(key))}</b></p>\n"
            f"<p>{plaintext_to_html(str(value))}</p>\n"
            "</div>\n"
        )
    return "".join(parts)


def predict(
    image: PIL.Image.Image,
    selected_model: str,
    score_threshold: float,
    models: dict,
    labels: list[str],
):
    """Tag ``image`` with the chosen model and collect its embedded metadata.

    Args:
        image: Input image; it is not modified.
        selected_model: Key into ``models``.
        score_threshold: Tags whose score is strictly greater than this are
            kept.
        models: Mapping from display name to ONNX Runtime session.
        labels: Tag names in the same order as the model's output scores; the
            first four are ratings.

    Returns:
        A 5-tuple ``(prompt, raw, rating, tags, info)``:
        - ``prompt``: kept tags by descending score, prompt-escaped.
        - ``raw``: the same tags, joined unmodified.
        - ``rating``: ``{name: score}`` for the four rating labels.
        - ``tags``: ``{name: score}`` for the kept tags.
        - ``info``: HTML describing the image's metadata.

    Raises:
        KeyError: if ``selected_model`` is not a key of ``models``.
    """
    model = models[selected_model]
    model_input = model.get_inputs()[0]
    # Input shape unpacks into four dims with height second (presumably NHWC).
    _, height, _, _ = model_input.shape
    batch = _prepare_input(image, height)

    output_name = model.get_outputs()[0].name
    probs = model.run([output_name], {model_input.name: batch})[0]
    scored = list(zip(labels, probs[0].astype(float)))

    # First 4 labels are actually ratings; everything else is a tag.
    rating = dict(scored[:4])
    tag_scores = dict(pair for pair in scored[4:] if pair[1] > score_threshold)

    prompt, raw = _format_tags(tag_scores)
    info = _image_metadata_html(image)
    return (prompt, raw, rating, tag_scores, info)
def main():
    """Load both taggers and the shared label list, then launch the Gradio UI."""
    args = parse_args()
    vit_model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
    conv_model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
    labels = load_labels()
    # The keys double as the radio choices below, so the UI options and the
    # lookup in predict() cannot drift apart.
    models = {"ViT": vit_model, "ConvNext": conv_model}
    func = functools.partial(predict, models=models, labels=labels)
    gr.Interface(
        fn=func,
        inputs=[
            gr.Image(type="pil", label="Input"),
            # Pre-select a model. Without a default the radio starts empty
            # (presumably submitting None), and predict() would then fail on
            # models[None].
            gr.Radio(list(models), value="ViT", label="Model"),
            gr.Slider(
                0,
                1,
                step=args.score_slider_step,
                value=args.score_threshold,
                label="Score Threshold",
            ),
        ],
        outputs=[
            gr.Textbox(label="Output (string)"),
            gr.Textbox(label="Output (raw string)"),
            gr.Label(label="Rating"),
            gr.Label(label="Output (label)"),
            gr.HTML(),
        ],
        examples=[["power.jpg", "ViT", 0.5]],
        title=TITLE,
        description=DESCRIPTION,
        allow_flagging="never",
    ).launch(
        enable_queue=True,
        share=args.share,
    )
# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()