Upload 9 files
- app.py +279 -0
- artist_predictor.pth +3 -0
- artist_tag_dict.json +0 -0
- character_predictor.pth +3 -0
- character_tag_dict.json +0 -0
- cls_predictor.pth +3 -0
- general_tag_dict.json +0 -0
- implications_list.json +0 -0
- requirements.txt +4 -0
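
The Space pairs NVIDIA's RADIO-H image encoder with three MLP heads scoring 9,775 general, 7,568 character, and 13,957 artist tags; app.py below wires them into a Gradio interface, the .pth files hold the head weights, and the JSON files hold the tag vocabularies.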
app.py
ADDED
@@ -0,0 +1,279 @@
import argparse
import json

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

TITLE = "Danbooru Tagger"
DESCRIPTION = """
Macro F1 (General & Character): 0.4790
"""

shortest_edge = 512
patch_size = 16

device = torch.device('cpu')

image_processor = CLIPImageProcessor.from_pretrained('nvidia/RADIO-H')
model = AutoModel.from_pretrained('nvidia/RADIO-H', trust_remote_code=True).to(device)
model.eval()


class MLP(nn.Module):
    """Classification head: a shared 3840 -> 1280 projection feeding three
    parallel 1280 -> 640 branches whose outputs are summed, followed by a
    sigmoid classifier over class_num tags."""

    def __init__(self, input_size, class_num):
        super().__init__()
        self.layers0 = nn.Sequential(
            nn.Linear(input_size, 1280),
            nn.LayerNorm(1280),
            nn.Mish()
        )
        self.layers1 = nn.Sequential(
            nn.Linear(640, class_num),
            nn.Sigmoid()
        )
        self.layers2 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )
        self.layers3 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )
        self.layers4 = nn.Sequential(
            nn.Linear(1280, 640),
            nn.LayerNorm(640),
            nn.Mish(),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        out = self.layers0(x)
        out = self.layers2(out) + self.layers3(out) + self.layers4(out)
        out = self.layers1(out)
        return out


with open('general_tag_dict.json', 'r', encoding='utf-8') as f:
    general_dict = json.load(f)

with open('character_tag_dict.json', 'r', encoding='utf-8') as f:
    character_dict = json.load(f)

with open('artist_tag_dict.json', 'r', encoding='utf-8') as f:
    artist_dict = json.load(f)

with open('implications_list.json', 'r', encoding='utf-8') as f:
    implications_list = json.load(f)

general_class = 9775
mlp_general = MLP(3840, general_class)
general_s = torch.load("cls_predictor.pth", map_location=device)
mlp_general.load_state_dict(general_s)
mlp_general.to(device)
mlp_general.eval()

character_class = 7568
mlp_character = MLP(3840, character_class)
character_s = torch.load("character_predictor.pth", map_location=device)
mlp_character.load_state_dict(character_s)
mlp_character.to(device)
mlp_character.eval()

artist_class = 13957
mlp_artist = MLP(3840, artist_class)
artist_s = torch.load("artist_predictor.pth", map_location=device)
mlp_artist.load_state_dict(artist_s)
mlp_artist.to(device)
mlp_artist.eval()


def prediction_to_tag(prediction, tag_dict, class_num, general_threshold, character_threshold, artist_threshold):
    prediction = prediction.view(class_num)
    # Hard pre-filter at 0.2, so per-category thresholds below 0.2 have no
    # extra effect. Class ids in the tag dictionaries are 1-based, hence the
    # +1 here and the -1 when indexing below.
    predicted_ids = (prediction >= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
    general = {}
    character = {}
    artist = {}
    date = {}
    rating = {}

    for tag, value in tag_dict.items():
        if value[2] in predicted_ids:
            tag_value = round(prediction[value[2] - 1].item(), 6)
            if value[1] == "general" and tag_value >= general_threshold:
                general[tag] = tag_value
            elif value[1] == "character" and tag_value >= character_threshold:
                character[tag] = tag_value
            elif value[1] == "artist" and tag_value >= artist_threshold:
                artist[tag] = tag_value
            elif value[1] == "rating":
                rating[tag] = tag_value
            elif value[1] == "date":
                date[tag] = tag_value

    general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
    character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))

    # date, rating and artist are each reduced to their single best entry
    if date:
        date = {max(date, key=date.get): date[max(date, key=date.get)]}
    if rating:
        rating = {max(rating, key=rating.get): rating[max(rating, key=rating.get)]}
    if artist:
        artist = {max(artist, key=artist.get): artist[max(artist, key=artist.get)]}

    return general, character, artist, date, rating


def process_image(image, general_threshold, character_threshold, artist_threshold):
    try:
        # Flatten any transparency onto a white background
        image = image.convert('RGBA')
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(background, image).convert('RGB')

        width, height = image.size

        # Resize so the shortest edge is 512, then round both edges to a
        # multiple of the patch size, e.g. 1000x1500 -> 512x768
        if width < height:
            height = int((shortest_edge / width) * height)
            width = shortest_edge
        else:
            width = int((shortest_edge / height) * width)
            height = shortest_edge

        height = int(round(height / patch_size) * patch_size)
        width = int(round(width / patch_size) * patch_size)

        height = max(height, patch_size)
        width = max(width, patch_size)

        image = image.resize((width, height), Image.LANCZOS)
        pixel_values = image_processor(images=image, return_tensors='pt', do_resize=True).pixel_values
        pixel_values = pixel_values.to(device).to(torch.bfloat16)

    except (OSError, IOError) as e:
        print(f"Error opening image: {e}")
        # Return empty values so all six Gradio output components get one
        return "", {}, {}, {}, {}, {}

    with torch.no_grad():
        summary, features = model(pixel_values)  # spatial features are unused
        outputs = summary.to(torch.float32)

        general_prediction = mlp_general(outputs)
        general_ = prediction_to_tag(general_prediction, general_dict, general_class, general_threshold, character_threshold, artist_threshold)
        general_tags = general_[0]
        rating = general_[4]

        character_prediction = mlp_character(outputs)
        character_ = prediction_to_tag(character_prediction, character_dict, character_class, general_threshold, character_threshold, artist_threshold)
        character_tags = character_[1]

        artist_prediction = mlp_artist(outputs)
        artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class, general_threshold, character_threshold, artist_threshold)
        artist_tags = artist_[2]
        date = artist_[3]

    combined_tags = {**artist_tags, **character_tags, **general_tags}

    # Drop any tag that is implied by another predicted tag
    tags_list = list(combined_tags)
    remove_list = []
    for tag in tags_list:
        if tag in implications_list:
            for implication in implications_list[tag]:
                remove_list.append(implication)
    tags_list = [tag for tag in tags_list if tag not in remove_list]

    tags_str = ", ".join(tags_list)

    return tags_str, artist_tags, character_tags, general_tags, rating, date


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--slider-step", type=float, default=0.01)
    parser.add_argument("--general-threshold", type=float, default=0.5)
    parser.add_argument("--character-threshold", type=float, default=0.76)
    parser.add_argument("--artist-threshold", type=float, default=0.68)
    return parser.parse_args()


def main():
    args = parse_args()

    with gr.Blocks(title=TITLE) as demo:
        with gr.Column():
            gr.Markdown(
                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
            )
            gr.Markdown(value=DESCRIPTION)
            with gr.Row():
                with gr.Column(variant="panel"):
                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                    with gr.Row():
                        general_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.general_threshold,
                            label="General Threshold",
                            scale=3,
                        )
                    with gr.Row():
                        character_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.character_threshold,
                            label="Character Threshold",
                            scale=3,
                        )
                    with gr.Row():
                        artist_threshold = gr.Slider(
                            0,
                            1,
                            step=args.slider_step,
                            value=args.artist_threshold,
                            label="Artist Threshold",
                            scale=3,
                        )
                    with gr.Row():
                        clear = gr.ClearButton(
                            components=[
                                image,
                            ],
                            variant="secondary",
                            size="lg",
                        )
                        submit = gr.Button(value="Submit", variant="primary", size="lg")
                with gr.Column(variant="panel"):
                    tags_str = gr.Textbox(label="Output")
                    artist_tags = gr.Label(label="Artist")
                    character_tags = gr.Label(label="Characters")
                    rating = gr.Label(label="Rating")
                    date = gr.Label(label="Year")
                    general_tags = gr.Label(label="General")
                    clear.add(
                        [
                            tags_str,
                            artist_tags,
                            general_tags,
                            character_tags,
                            rating,
                            date
                        ]
                    )

        submit.click(
            process_image,
            inputs=[
                image,
                general_threshold,
                character_threshold,
                artist_threshold
            ],
            outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date],
        )

    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()
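
Note that prediction_to_tag assumes a specific layout for the tag dictionaries: each key is a tag name, and each value is a list where value[1] is the category string ("general", "character", "artist", "rating", or "date") and value[2] is that tag's 1-based index in the corresponding head's output. A minimal sketch of the assumed layout, with hypothetical entries (the real ones ship in the JSON files in this commit):

# Hypothetical entries illustrating the layout prediction_to_tag indexes;
# value[0] appears unused by app.py.
example_tag_dict = {
    "1girl": [0, "general", 1],
    "hatsune_miku": [0, "character", 42],
}

# implications_list.json maps a tag to the tags it implies; process_image
# prunes implied tags from the output string. Hypothetical entry:
example_implications = {"white_thighhighs": ["thighhighs"]}

Because the prediction tensor is indexed with value[2] - 1, the ids in these dictionaries must start at 1, not 0.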
artist_predictor.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:781a78126d3f1e818ad699344318edd17cb309290d71c6499b170ce15eb58cea
size 65322132
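
(This is a Git LFS pointer file, not the checkpoint itself; the actual weights are stored in LFS and resolved on checkout. The same applies to character_predictor.pth and cls_predictor.pth below.)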
artist_tag_dict.json
ADDED
The diff for this file is too large to render.
character_predictor.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0cb43260537161ac6896568329c130d7197c4c226b1f57a0ee7e7eeffd2841e6
size 48940690
character_tag_dict.json
ADDED
The diff for this file is too large to render.
cls_predictor.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b0bb58f320b941f20d9c3b9e3af4dc87780d9cf3f9d50be7a72b684028cd7763
size 54599508
general_tag_dict.json
ADDED
The diff for this file is too large to render.
implications_list.json
ADDED
The diff for this file is too large to render.
requirements.txt
ADDED
@@ -0,0 +1,4 @@
torch
transformers
Pillow
gradio
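
With the checkpoints and JSON dictionaries in the working directory, the Space can also be run locally: install the four (unpinned) dependencies with pip install -r requirements.txt, then start the app with python app.py. The slider step and the three threshold defaults (0.5 / 0.76 / 0.68) can be overridden via the argparse flags defined in parse_args(); values below the hardcoded 0.2 pre-filter in prediction_to_tag have no additional effect.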