Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,611 Bytes
c9dfb9e 76ac8f3 b36fcf5 76ac8f3 c1c88a4 76ac8f3 c9dfb9e aa2fbfe c9dfb9e 76ac8f3 c9dfb9e aa2fbfe c9dfb9e aa2fbfe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import gradio as gr
import csv
import os
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import spaces
# Side length, in pixels, of the square image that preprocess_image produces
# and that is fed to the ONNX tagger.
IMAGE_SIZE: int = 448
def preprocess_image(image, target_size=None):
    """Turn an RGB image into the float32 square BGR array fed to the tagger.

    Steps: flip the channel order (RGB -> BGR), pad the shorter side with
    white (255) so the picture becomes square with its content centred, then
    resize to ``target_size`` x ``target_size``.

    Args:
        image: a PIL image (or array-like) whose ``np.array`` form is
            (H, W, 3) in RGB order; ``main`` converts to RGB before calling.
        target_size: output side length in pixels. ``None`` (the default)
            uses the module-level ``IMAGE_SIZE``.

    Returns:
        ``np.float32`` array of shape (target_size, target_size, 3) in BGR
        channel order. Pixel values are not rescaled or normalised.
    """
    if target_size is None:
        target_size = IMAGE_SIZE

    image = np.array(image)
    # PIL delivers RGB; reversing the last axis gives BGR, which is presumably
    # what the WD tagger models expect (as in SmilingWolf's reference code).
    image = image[:, :, ::-1]

    # Pad with white so the image becomes square, keeping content centred.
    side = max(image.shape[:2])
    pad_x = side - image.shape[1]
    pad_y = side - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(
        image,
        ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # Area averaging when shrinking, Lanczos when enlarging.
    interp = cv2.INTER_AREA if side > target_size else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (target_size, target_size), interpolation=interp)
    return image.astype(np.float32)
def _load_rgb_image(image_path):
    """Open ``image_path`` and return an in-memory RGB PIL image.

    Transparent images (RGBA/LA/PA, or palette images with a transparency
    entry) are composited onto white rather than having their alpha simply
    dropped. This keeps transparent regions consistent with the white
    padding used during preprocessing.
    """
    with Image.open(image_path) as img:
        has_alpha = img.mode in ("RGBA", "LA", "PA") or (
            img.mode == "P" and "transparency" in img.info
        )
        if has_alpha:
            rgba = img.convert("RGBA")
            canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            canvas.alpha_composite(rgba)
            return canvas.convert("RGB")
        # convert() returns a new, fully loaded image (a copy even when the
        # mode is already RGB), so it remains valid after the file closes.
        return img.convert("RGB")


def _read_tag_table(csv_path):
    """Parse selected_tags.csv into (row indices, tag names) per category.

    Column 1 holds the tag name and column 2 the category code. Returns
    three ``(indices, names)`` pairs, in this order: rating ("9"),
    general ("0"), character ("4").
    """
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        rows = [row for row in reader if row]  # ignore blank lines

    def pick(category):
        indices = [i for i, row in enumerate(rows) if row[2] == category]
        return indices, [rows[i][1] for i in indices]

    return pick("9"), pick("0"), pick("4")


@spaces.GPU
def main(image_path, model_id):
    """Run an ONNX tagger from the Hugging Face Hub on one image.

    Downloads ``model.onnx`` and ``selected_tags.csv`` from ``model_id``,
    preprocesses the image, runs inference and hands the scores to
    ``evaluate_tags``.

    Args:
        image_path: path of the uploaded image (``gr.Image(type='filepath')``),
            or ``None`` when nothing has been uploaded.
        model_id: Hugging Face Hub repository id of the tagger model.

    Returns:
        ``(NSFW_flag, IP_flag, tag_text)`` as produced by ``evaluate_tags``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image_path is None:
        raise gr.Error("画像をアップロードしてください")

    print("Hugging Faceからモデルをダウンロード中")
    onnx_path = hf_hub_download(model_id, "model.onnx")
    csv_path = hf_hub_download(model_id, "selected_tags.csv")

    # Preprocess and wrap in a batch of one.
    img = np.array([preprocess_image(_load_rgb_image(image_path))])

    ort_sess = ort.InferenceSession(onnx_path)
    input_name = ort_sess.get_inputs()[0].name
    prob = ort_sess.run(None, {input_name: img})[0][0]

    (
        (rating_idx, rating_tags),
        (general_idx, general_tags),
        (character_idx, character_tags),
    ) = _read_tag_table(csv_path)

    # evaluate_tags expects rating, general and character scores to be
    # contiguous in that order. Gather them explicitly by CSV row index so
    # this holds regardless of the file's row order (a no-op for a CSV
    # already sorted that way).
    ordered_prob = prob[rating_idx + general_idx + character_idx]

    return evaluate_tags(ordered_prob, rating_tags, character_tags, general_tags)
def evaluate_tags(prob, rating_tags, character_tags, general_tags, thresh=0.35):
    """Turn tagger scores into an NSFW verdict, an IP verdict and a tag string.

    ``prob`` must hold the rating scores first, then the general scores, then
    the character scores, each in the same order as the corresponding name
    list. This is the row order of a category-sorted ``selected_tags.csv``.

    Args:
        prob: sequence (list or 1-D NumPy array) of per-tag scores.
        rating_tags: rating tag names (e.g. "general", "questionable", ...).
        character_tags: character tag names.
        general_tags: general tag names.
        thresh: minimum score for a general/character tag to be reported.

    Returns:
        ``(NSFW_flag, IP_flag, tag_text)``:
        - NSFW_flag: Japanese sentence saying whether NSFW is likely.
        - IP_flag: Japanese sentence listing detected characters with
          percentages.
        - tag_text: comma-separated general tags followed by character tags.
    """
    n_rating = len(rating_tags)
    n_general = len(general_tags)
    general_start = n_rating
    character_start = n_rating + n_general

    rating_probs = prob[:n_rating]
    general_probs = prob[general_start:character_start]
    character_probs = prob[character_start:character_start + len(character_tags)]

    # NSFW/SFW decision: the stronger of questionable/explicit versus
    # general. Note that "sensitive" is counted on neither side.
    tag_confidences = dict(zip(rating_tags, rating_probs))
    max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
    max_sfw_score = tag_confidences.get("general", 0)
    NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"

    # Characters (possible copyrighted IP) whose score passes the threshold.
    detected_characters = [
        (tag, p) for tag, p in zip(character_tags, character_probs) if p >= thresh
    ]
    character_tags_with_probs = [
        (tag, f"{round(p * 100, 2)}%") for tag, p in detected_characters
    ]
    IP_flag = (
        f"版権キャラクター: {character_tags_with_probs}の可能性があります"
        if character_tags_with_probs
        else "版権キャラクターの可能性が低いと思われます"
    )

    general_tag_text = ", ".join(
        tag for tag, p in zip(general_tags, general_probs) if p >= thresh
    )
    character_tag_text = ", ".join(tag for tag, _ in detected_characters)
    # Join only non-empty parts so there is never a dangling ", ".
    tag_text = ", ".join(part for part in (general_tag_text, character_tag_text) if part)
    return NSFW_flag, IP_flag, tag_text
class webui:
    """Gradio Blocks front-end: image + model ID in, three text verdicts out."""

    def __init__(self):
        # Empty Blocks container; the widgets are added in launch().
        self.demo = gr.Blocks()

    def launch(self):
        """Lay out the widgets, wire the button to ``main`` and start serving."""
        # Multiple context managers in one ``with`` nest exactly like
        # separate ``with`` statements: demo -> Row -> Column.
        with self.demo, gr.Row(), gr.Column():
            picture = gr.Image(type='filepath', label="Analysis Image")
            repo_box = gr.Textbox(label="Model ID", value="SmilingWolf/wd-vit-tagger-v3")
            nsfw_box = gr.Textbox(label="NSFW Flag")
            ip_box = gr.Textbox(label="IP Flag")
            tags_box = gr.Textbox(label="Tags")
            run_button = gr.Button(value="Start Analysis")

            run_button.click(
                main,
                inputs=[picture, repo_box],
                outputs=[nsfw_box, ip_box, tags_box],
            )

        # share=True asks Gradio for a public link.
        self.demo.launch(share=True)
if __name__ == "__main__":
    # Script entry point: build the UI wrapper and start the Gradio server.
    webui().launch()
|