Johnny-Z commited on
Commit
c08a487
·
verified ·
1 Parent(s): 3f1281a

Upload 9 files

Browse files
app.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from transformers import AutoModel, CLIPImageProcessor
3
+ import torch
4
+ import json
5
+ import torch.nn as nn
6
+ from PIL import Image
7
+ import gradio as gr
8
+
9
+ TITLE = "Danbooru Tagger"
10
+ DESCRIPTION = """
11
+ Macro F1 (General & Character): 0.4790
12
+ """
13
+
14
+ shortest_edge = 512
15
+ patch_size = 16
16
+
17
+ device = torch.device('cpu')
18
+
19
+ image_processor = CLIPImageProcessor.from_pretrained('nvidia/RADIO-H')
20
+ model = AutoModel.from_pretrained('nvidia/RADIO-H', trust_remote_code=True).to(device)
21
+ model.eval()
22
+
23
+ class MLP(nn.Module):
24
+ def __init__(self, input_size, class_num):
25
+ super().__init__()
26
+ self.layers0 = nn.Sequential(
27
+ nn.Linear(input_size, 1280),
28
+ nn.LayerNorm(1280),
29
+ nn.Mish()
30
+ )
31
+ self.layers1 = nn.Sequential(
32
+ nn.Linear(640, class_num),
33
+ nn.Sigmoid()
34
+ )
35
+ self.layers2 = nn.Sequential(
36
+ nn.Linear(1280, 640),
37
+ nn.LayerNorm(640),
38
+ nn.Mish(),
39
+ nn.Dropout(0.2)
40
+ )
41
+ self.layers3 = nn.Sequential(
42
+ nn.Linear(1280, 640),
43
+ nn.LayerNorm(640),
44
+ nn.Mish(),
45
+ nn.Dropout(0.2)
46
+ )
47
+ self.layers4 = nn.Sequential(
48
+ nn.Linear(1280, 640),
49
+ nn.LayerNorm(640),
50
+ nn.Mish(),
51
+ nn.Dropout(0.2)
52
+ )
53
+
54
+ def forward(self, x):
55
+ out = self.layers0(x)
56
+ out = self.layers2(out) + self.layers3(out) + self.layers4(out)
57
+ out = self.layers1(out)
58
+ return out
59
+
60
+ with open('general_tag_dict.json', 'r', encoding='utf-8') as f:
61
+ general_dict = json.load(f)
62
+
63
+ with open('character_tag_dict.json', 'r', encoding='utf-8') as f:
64
+ character_dict = json.load(f)
65
+
66
+ with open('artist_tag_dict.json', 'r', encoding='utf-8') as f:
67
+ artist_dict = json.load(f)
68
+
69
+ with open('implications_list.json', 'r', encoding='utf-8') as f:
70
+ implications_list = json.load(f)
71
+
72
+ general_class = 9775
73
+ mlp_general = MLP(3840, general_class)
74
+ general_s = torch.load("cls_predictor.pth", map_location=device)
75
+ mlp_general.load_state_dict(general_s)
76
+ mlp_general.to(device)
77
+ mlp_general.eval()
78
+
79
+ character_class = 7568
80
+ mlp_character = MLP(3840, character_class)
81
+ character_s = torch.load("character_predictor.pth", map_location=device)
82
+ mlp_character.load_state_dict(character_s)
83
+ mlp_character.to(device)
84
+ mlp_character.eval()
85
+
86
+ artist_class = 13957
87
+ mlp_artist = MLP(3840, artist_class)
88
+ artist_s = torch.load("artist_predictor.pth", map_location=device)
89
+ mlp_artist.load_state_dict(artist_s)
90
+ mlp_artist.to(device)
91
+ mlp_artist.eval()
92
+
93
+ def prediction_to_tag(prediction, tag_dict, class_num, general_threshold, character_threshold, artist_threshold):
94
+ prediction = prediction.view(class_num)
95
+ predicted_ids = (prediction>= 0.2).nonzero(as_tuple=True)[0].cpu().numpy() + 1
96
+ general = {}
97
+ character = {}
98
+ artist = {}
99
+ date = {}
100
+ rating = {}
101
+
102
+ for tag, value in tag_dict.items():
103
+ if value[2] in predicted_ids:
104
+ tag_value = round(prediction[value[2] - 1].item(), 6)
105
+ if value[1] == "general" and tag_value >= general_threshold:
106
+ general[tag] = tag_value
107
+ elif value[1] == "character" and tag_value >= character_threshold:
108
+ character[tag] = tag_value
109
+ elif value[1] == "artist" and tag_value >= artist_threshold:
110
+ artist[tag] = tag_value
111
+ elif value[1] == "rating":
112
+ rating[tag] = tag_value
113
+ elif value[1] == "date":
114
+ date[tag] = tag_value
115
+
116
+ general = dict(sorted(general.items(), key=lambda item: item[1], reverse=True))
117
+ character = dict(sorted(character.items(), key=lambda item: item[1], reverse=True))
118
+
119
+ if date:
120
+ date = {max(date, key=date.get): date[max(date, key=date.get)]}
121
+ if rating:
122
+ rating = {max(rating, key=rating.get): rating[max(rating, key=rating.get)]}
123
+ if artist:
124
+ artist = {max(artist, key=artist.get): artist[max(artist, key=artist.get)]}
125
+
126
+ return general, character, artist, date, rating
127
+
128
+
129
+ def process_image(image, general_threshold, character_threshold, artist_threshold):
130
+ try:
131
+ image = image.convert('RGBA')
132
+ background = Image.new('RGBA', image.size, (255, 255, 255, 255))
133
+ image = Image.alpha_composite(background, image).convert('RGB')
134
+
135
+ width, height = image.size
136
+
137
+ if width < height:
138
+ height = int((shortest_edge / width) * height)
139
+ width = shortest_edge
140
+ else:
141
+ width = int((shortest_edge / height) * width)
142
+ height = shortest_edge
143
+
144
+ height = int(round(height / patch_size) * patch_size)
145
+ width = int(round(width / patch_size) * patch_size)
146
+
147
+ height = max(height, patch_size)
148
+ width = max(width, patch_size)
149
+
150
+ image = image.resize((width, height), Image.LANCZOS)
151
+ pixel_values = image_processor(images=image, return_tensors='pt', do_resize=True).pixel_values
152
+ pixel_values = pixel_values.to(device).to(torch.bfloat16)
153
+
154
+ except (OSError, IOError) as e:
155
+ print(f"Error opening image: {e}")
156
+ return
157
+
158
+ with torch.no_grad():
159
+ summary, features = model(pixel_values)
160
+ outputs = summary.to(torch.float32)
161
+
162
+ general_prediction = mlp_general(outputs)
163
+ general_ = prediction_to_tag(general_prediction, general_dict, general_class, general_threshold, character_threshold, artist_threshold)
164
+ general_tags = general_[0]
165
+ rating = general_[4]
166
+
167
+ character_prediction = mlp_character(outputs)
168
+ character_ = prediction_to_tag(character_prediction, character_dict, character_class, general_threshold, character_threshold, artist_threshold)
169
+ character_tags = character_[1]
170
+
171
+ artist_prediction = mlp_artist(outputs)
172
+ artist_ = prediction_to_tag(artist_prediction, artist_dict, artist_class, general_threshold, character_threshold, artist_threshold)
173
+ artist_tags = artist_[2]
174
+ date = artist_[3]
175
+
176
+ combined_tags = {**artist_tags, **character_tags, **general_tags}
177
+
178
+ tags_list = [tag for tag in combined_tags]
179
+ remove_list = []
180
+ for tag in tags_list:
181
+ if tag in implications_list:
182
+ for implication in implications_list[tag]:
183
+ remove_list.append(implication)
184
+ tags_list = [tag for tag in tags_list if tag not in remove_list]
185
+
186
+ tags_str = ", ".join(tags_list)
187
+
188
+ return tags_str, artist_tags, character_tags, general_tags, rating, date
189
+
190
+ def parse_args() -> argparse.Namespace:
191
+ parser = argparse.ArgumentParser()
192
+ parser.add_argument("--slider-step", type=float, default=0.01)
193
+ parser.add_argument("--general-threshold", type=float, default=0.5)
194
+ parser.add_argument("--character-threshold", type=float, default=0.76)
195
+ parser.add_argument("--artist-threshold", type=float, default=0.68)
196
+ return parser.parse_args()
197
+
198
+ def main():
199
+ args = parse_args()
200
+
201
+ with gr.Blocks(title=TITLE) as demo:
202
+ with gr.Column():
203
+ gr.Markdown(
204
+ value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
205
+ )
206
+ gr.Markdown(value=DESCRIPTION)
207
+ with gr.Row():
208
+ with gr.Column(variant="panel"):
209
+ image = gr.Image(type="pil", image_mode="RGBA", label="Input")
210
+ with gr.Row():
211
+ general_threshold = gr.Slider(
212
+ 0,
213
+ 1,
214
+ step=args.slider_step,
215
+ value=args.general_threshold,
216
+ label="General Threshold",
217
+ scale=3,
218
+ )
219
+ with gr.Row():
220
+ character_threshold = gr.Slider(
221
+ 0,
222
+ 1,
223
+ step=args.slider_step,
224
+ value=args.character_threshold,
225
+ label="Character Threshold",
226
+ scale=3,
227
+ )
228
+ with gr.Row():
229
+ artist_threshold = gr.Slider(
230
+ 0,
231
+ 1,
232
+ step=args.slider_step,
233
+ value=args.artist_threshold,
234
+ label="Artist Threshold",
235
+ scale=3,
236
+ )
237
+ with gr.Row():
238
+ clear = gr.ClearButton(
239
+ components=[
240
+ image,
241
+ ],
242
+ variant="secondary",
243
+ size="lg",
244
+ )
245
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
246
+ with gr.Column(variant="panel"):
247
+ tags_str = gr.Textbox(label="Output")
248
+ artist_tags = gr.Label(label="Artist")
249
+ character_tags = gr.Label(label="Characters")
250
+ rating = gr.Label(label="Rating")
251
+ date = gr.Label(label="Year")
252
+ general_tags = gr.Label(label="General")
253
+ clear.add(
254
+ [
255
+ tags_str,
256
+ artist_tags,
257
+ general_tags,
258
+ character_tags,
259
+ rating,
260
+ date
261
+ ]
262
+ )
263
+
264
+ submit.click(
265
+ process_image,
266
+ inputs=[
267
+ image,
268
+ general_threshold,
269
+ character_threshold,
270
+ artist_threshold
271
+ ],
272
+ outputs=[tags_str, artist_tags, character_tags, general_tags, rating, date],
273
+ )
274
+
275
+ demo.queue(max_size=10)
276
+ demo.launch()
277
+
278
+ if __name__ == "__main__":
279
+ main()
artist_predictor.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:781a78126d3f1e818ad699344318edd17cb309290d71c6499b170ce15eb58cea
3
+ size 65322132
artist_tag_dict.json ADDED
The diff for this file is too large to render. See raw diff
 
character_predictor.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cb43260537161ac6896568329c130d7197c4c226b1f57a0ee7e7eeffd2841e6
3
+ size 48940690
character_tag_dict.json ADDED
The diff for this file is too large to render. See raw diff
 
cls_predictor.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0bb58f320b941f20d9c3b9e3af4dc87780d9cf3f9d50be7a72b684028cd7763
3
+ size 54599508
general_tag_dict.json ADDED
The diff for this file is too large to render. See raw diff
 
implications_list.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ Pillow
4
+ gradio