Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import csv | |
import os | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import onnxruntime as ort | |
from huggingface_hub import hf_hub_download | |
import spaces | |
# Side length, in pixels, of the square image produced by preprocess_image.
IMAGE_SIZE = 448


def preprocess_image(image):
    """Convert an image into a square IMAGE_SIZE x IMAGE_SIZE float32 array.

    The channel axis is reversed, the shorter side is padded evenly with
    white (255) until the picture is square, and the result is resized to
    IMAGE_SIZE x IMAGE_SIZE. Pixel values are cast to float32 but are not
    rescaled.

    Args:
        image: Anything ``np.array`` turns into a height x width x channels
            array. ``main`` passes an RGB PIL image, so the channel
            reversal yields BGR order.

    Returns:
        The padded and resized pixels as a float32 array.
    """
    # Reverse the channel axis (RGB input from the caller becomes BGR).
    pixels = np.array(image)[:, :, ::-1]

    # Pad the shorter dimension on both sides so the image becomes square.
    height, width = pixels.shape[:2]
    side = max(height, width)
    extra_rows = side - height
    extra_cols = side - width
    top = extra_rows // 2
    left = extra_cols // 2
    padding = ((top, extra_rows - top), (left, extra_cols - left), (0, 0))
    squared = np.pad(pixels, padding, mode="constant", constant_values=255)

    # Area interpolation when shrinking, Lanczos otherwise.
    if side > IMAGE_SIZE:
        method = cv2.INTER_AREA
    else:
        method = cv2.INTER_LANCZOS4
    resized = cv2.resize(squared, (IMAGE_SIZE, IMAGE_SIZE), interpolation=method)

    return resized.astype(np.float32)
def main(image_path, model_id):
    """Tag one image with an ONNX tagger model hosted on Hugging Face.

    Downloads ``model.onnx`` and ``selected_tags.csv`` from ``model_id``,
    runs the preprocessed image through the model and hands the scores,
    together with the tag names grouped by category, to ``evaluate_tags``.

    Args:
        image_path: Path of the image file to analyse.
        model_id: Hugging Face repository id that provides the two files.

    Returns:
        The ``(NSFW_flag, IP_flag, tag_text)`` tuple from ``evaluate_tags``.

    Raises:
        gr.Error: If no image path was supplied (nothing uploaded yet).
    """
    # Without an upload the image component yields no path; fail with a
    # readable message instead of an opaque error from inside PIL.
    if not image_path:
        raise gr.Error("Please upload an image before starting the analysis.")

    print("Hugging Faceからモデルをダウンロード中")
    onnx_path = hf_hub_download(model_id, "model.onnx")
    csv_path = hf_hub_download(model_id, "selected_tags.csv")

    # Read and preprocess the image while the file is open, then release
    # the file handle as soon as the pixel data has been extracted.
    with Image.open(image_path) as source:
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        pixels = preprocess_image(rgb)
    img = np.array([pixels])  # add the leading batch dimension

    # The inference session is created per call, as in the original code.
    ort_sess = ort.InferenceSession(onnx_path)
    input_name = ort_sess.get_inputs()[0].name
    prob = ort_sess.run(None, {input_name: img})[0][0]

    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        rows = list(reader)

    # Column 1 holds the tag name, column 2 the category code.
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    general_tags = [row[1] for row in rows if row[2] == "0"]

    NSFW_flag, IP_flag, tag_text = evaluate_tags(prob, rating_tags, character_tags, general_tags)
    return NSFW_flag, IP_flag, tag_text
def evaluate_tags(prob, rating_tags, character_tags, general_tags):
    """Turn model scores into an NSFW verdict, a character verdict and a tag string.

    ``prob`` is read as three consecutive segments: one score per rating
    tag, then one per general tag, then one per character tag. Every
    segment is sliced with the same offsets, so a tag is always paired
    with its own score.

    Args:
        prob: Sequence of scores (list or 1-D array) laid out as above.
        rating_tags: Names of the rating tags, in score order.
        character_tags: Names of the character tags, in score order.
        general_tags: Names of the general tags, in score order.

    Returns:
        A ``(NSFW_flag, IP_flag, tag_text)`` tuple of strings: the NSFW/SFW
        message, the copyrighted-character message, and the comma-separated
        general and character tags whose score reached the threshold.
    """
    thresh = 0.35

    # Segment boundaries inside ``prob``.
    general_start = len(rating_tags)
    character_start = general_start + len(general_tags)
    character_end = character_start + len(character_tags)

    # NSFW/SFW decision: compare the strongest NSFW rating with "general".
    rating_scores = dict(zip(rating_tags, prob[:general_start]))
    max_nsfw_score = max(rating_scores.get("questionable", 0), rating_scores.get("explicit", 0))
    max_sfw_score = rating_scores.get("general", 0)
    NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"

    # zip() stops at the shorter input, so a score vector that is shorter
    # than the tag lists cannot cause an out-of-range lookup.
    general_hits = [
        tag
        for tag, p in zip(general_tags, prob[general_start:character_start])
        if p >= thresh
    ]
    character_hits = [
        (tag, p)
        for tag, p in zip(character_tags, prob[character_start:character_end])
        if p >= thresh
    ]

    # Copyrighted-character verdict, with each score shown as a percentage.
    character_tags_with_probs = [(tag, f"{round(p * 100, 2)}%") for tag, p in character_hits]
    IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"

    # Joining one combined list avoids a stray leading separator when only
    # character tags passed the threshold.
    tag_text = ", ".join(general_hits + [tag for tag, _ in character_hits])
    return NSFW_flag, IP_flag, tag_text
class webui:
    """Gradio front end that wires an image and a model id to ``main``."""

    def __init__(self):
        # The Blocks container is created up front and populated in launch().
        self.demo = gr.Blocks()

    def launch(self):
        """Build the widgets, connect the button to ``main`` and start the app."""
        # NOTE(review): the nesting below was reconstructed from source whose
        # indentation was lost — confirm which widgets belong in the column.
        with self.demo:
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type='filepath', label="Analysis Image")
                    model_input = gr.Textbox(label="Model ID", value="SmilingWolf/wd-vit-tagger-v3")
                    nsfw_output = gr.Textbox(label="NSFW Flag")
                    ip_output = gr.Textbox(label="IP Flag")
                    tags_output = gr.Textbox(label="Tags")
                    start_button = gr.Button(value="Start Analysis")

            # One click runs the whole pipeline and fills the three text boxes.
            start_button.click(
                main,
                inputs=[image_input, model_input],
                outputs=[nsfw_output, ip_output, tags_output],
            )

        self.demo.launch(share=True)  # request a public share link
if __name__ == "__main__":
    # Build and serve the interface only when this file is run as a script.
    app = webui()
    app.launch()