File size: 4,611 Bytes
c9dfb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76ac8f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b36fcf5
76ac8f3
 
c1c88a4
76ac8f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9dfb9e
 
 
 
 
 
 
 
 
aa2fbfe
c9dfb9e
 
 
 
 
 
76ac8f3
c9dfb9e
 
 
 
aa2fbfe
c9dfb9e
 
 
aa2fbfe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import csv
import os
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import spaces

# Side length, in pixels, of the square image handed to the ONNX model;
# preprocess_image pads and resizes every input to IMAGE_SIZE x IMAGE_SIZE.
IMAGE_SIZE = 448

def preprocess_image(image, image_size=None):
    """Convert an RGB image into the square float32 array fed to the tagger.

    The image is padded (centred) to a square with white pixels, resized to
    ``image_size`` x ``image_size``, and returned with its channel axis
    reversed (RGB in, BGR out). Pixel values stay in the 0-255 range; no
    normalisation is applied.

    Args:
        image: PIL image or array-like in RGB channel order. A 2-D (grayscale)
            input is expanded to three channels. A 4-channel input has its
            alpha composited onto white; this assumes 8-bit channels.
        image_size: Side length of the output square. ``None`` (the default)
            uses the module-level ``IMAGE_SIZE``, read at call time.

    Returns:
        ``np.ndarray`` of shape ``(image_size, image_size, 3)``, dtype float32.
    """
    target_size = IMAGE_SIZE if image_size is None else image_size

    image = np.array(image)
    if image.ndim == 2:
        # Grayscale: replicate the single channel so the array has 3 channels.
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        # Flatten transparency onto white, the same colour used for padding.
        alpha = image[:, :, 3:4].astype(np.float32) / 255.0
        rgb = image[:, :, :3].astype(np.float32)
        image = np.rint(rgb * alpha + 255.0 * (1.0 - alpha)).astype(np.uint8)

    # Reverse the channel axis: PIL's RGB becomes BGR (presumably the order
    # the tagger model was trained on).
    image = image[:, :, ::-1]

    # Pad to a centred white square so resizing keeps the aspect ratio.
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(
        image,
        ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # INTER_AREA when shrinking (avoids aliasing), Lanczos when enlarging.
    interp = cv2.INTER_AREA if size > target_size else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (target_size, target_size), interpolation=interp)
    return image.astype(np.float32)

@spaces.GPU
def main(image_path, model_id):
    """Tag one image with an ONNX tagger downloaded from the Hugging Face Hub.

    Args:
        image_path: Filesystem path of the uploaded image (the Gradio input
            uses ``type='filepath'``). It is falsy when nothing was uploaded.
        model_id: Hub repo id that provides ``model.onnx`` and
            ``selected_tags.csv``.

    Returns:
        Tuple ``(NSFW_flag, IP_flag, tag_text)`` of display strings produced by
        ``evaluate_tags``.

    Raises:
        gr.Error: if no image was supplied, or if the model's output length
            does not match the number of tag rows in ``selected_tags.csv``.
    """
    if not image_path:
        raise gr.Error("Please upload an image before starting the analysis.")

    print("Hugging Faceからモデルをダウンロード中")
    onnx_path = hf_hub_download(model_id, "model.onnx")
    csv_path = hf_hub_download(model_id, "selected_tags.csv")

    # Load the image; the context manager closes the underlying file handle.
    with Image.open(image_path) as pil_image:
        if "A" in pil_image.getbands() or "transparency" in pil_image.info:
            # A plain convert("RGB") just drops alpha, exposing whatever
            # colour lies under transparent pixels (often black). Composite
            # onto white instead.
            rgba = pil_image.convert("RGBA")
            background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            background.alpha_composite(rgba)
            rgb_image = background.convert("RGB")
        else:
            rgb_image = pil_image.convert("RGB")
    image = preprocess_image(rgb_image)
    img = np.array([image])  # add a batch axis of size 1

    # Create the ONNX session and run inference; take batch item 0 of output 0.
    ort_sess = ort.InferenceSession(onnx_path)
    input_name = ort_sess.get_inputs()[0].name
    prob = ort_sess.run(None, {input_name: img})[0][0]

    # Read the tag table. newline="" is what the csv module expects.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        rows = list(reader)

    # evaluate_tags maps scores to tags by position, so the lengths must agree.
    if len(prob) != len(rows):
        raise gr.Error(
            f"Model output has {len(prob)} scores but selected_tags.csv "
            f"lists {len(rows)} tags."
        )

    # Column 1 is the tag name, column 2 its category code.
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    general_tags = [row[1] for row in rows if row[2] == "0"]

    NSFW_flag, IP_flag, tag_text = evaluate_tags(prob, rating_tags, character_tags, general_tags)
    return NSFW_flag, IP_flag, tag_text

def evaluate_tags(prob, rating_tags, character_tags, general_tags,
                  general_thresh=0.35, character_thresh=0.35):
    """Turn per-tag model probabilities into the three strings shown in the UI.

    ``prob`` is expected to be laid out as: all rating tags first, then all
    general tags, then all character tags. Each segment is in the same order
    as the corresponding name list.

    Args:
        prob: 1-D sequence (list or numpy array) of per-tag probabilities.
        rating_tags: Names of the rating tags, e.g. "general", "questionable",
            "explicit".
        character_tags: Names of the character tags.
        general_tags: Names of the general tags.
        general_thresh: Minimum probability for a general tag to be reported.
        character_thresh: Minimum probability for a character tag to be
            reported.

    Returns:
        Tuple ``(NSFW_flag, IP_flag, tag_text)``:
        - ``NSFW_flag`` is a Japanese NSFW/SFW verdict.
        - ``IP_flag`` lists the detected characters with percentages.
        - ``tag_text`` holds the comma-separated tags above threshold, general
          tags first.
    """
    # Slice prob into its three segments instead of hard-coding offsets; the
    # previous index-based code forgot to skip the rating entries.
    general_start = len(rating_tags)
    character_start = general_start + len(general_tags)
    rating_probs = prob[:general_start]
    general_probs = prob[general_start:character_start]
    character_probs = prob[character_start:character_start + len(character_tags)]

    # NSFW/SFW verdict: strongest of questionable/explicit vs. "general".
    tag_confidences = dict(zip(rating_tags, rating_probs))
    max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
    max_sfw_score = tag_confidences.get("general", 0)
    NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"

    # Copyrighted-character detection, with probabilities as percentages.
    character_tags_with_probs = [
        (tag_name, f"{round(float(p) * 100, 2)}%")
        for tag_name, p in zip(character_tags, character_probs)
        if p >= character_thresh
    ]
    IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"

    # Tag string: general tags, then character tags. Empty parts are skipped
    # so there is no stray leading ", ".
    general_tag_text = ", ".join(
        tag_name for tag_name, p in zip(general_tags, general_probs) if p >= general_thresh
    )
    character_tag_text = ", ".join(tag_name for tag_name, _ in character_tags_with_probs)
    tag_text = ", ".join(part for part in (general_tag_text, character_tag_text) if part)
    return NSFW_flag, IP_flag, tag_text






class webui:
    """Gradio front end that wires an image and a model id to ``main``."""

    def __init__(self):
        self.demo = gr.Blocks()

    def launch(self, share=True):
        """Add the components to ``self.demo`` and start the Gradio server.

        Each call re-enters ``self.demo`` and creates the components again,
        so call it once per instance.

        Args:
            share: Forwarded to ``Blocks.launch``. ``True`` (the default,
                matching the previous hard-coded value) requests a public
                share link.
        """
        with self.demo:
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', label="Analysis Image")
                    model_id = gr.Textbox(label="Model ID", value="SmilingWolf/wd-vit-tagger-v3")
                    output_0 = gr.Textbox(label="NSFW Flag")
                    output_1 = gr.Textbox(label="IP Flag")
                    output_2 = gr.Textbox(label="Tags")
                    submit = gr.Button(value="Start Analysis")

            # Run the analysis when the button is clicked.
            submit.click(
                main,
                inputs=[input_image, model_id],
                outputs=[output_0, output_1, output_2],
            )

        self.demo.launch(share=share)

if __name__ == "__main__":
    # Script entry point: build the Gradio interface and start serving it.
    webui().launch()