Kotomiya07 commited on
Commit
98ab92d
1 Parent(s): ed8533e

Add application file

Browse files
__pycache__/config.cpython-310.pyc CHANGED
Binary files a/__pycache__/config.cpython-310.pyc and b/__pycache__/config.cpython-310.pyc differ
 
__pycache__/tags.cpython-310.pyc ADDED
Binary file (9.11 kB). View file
 
__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/__pycache__/utils.cpython-310.pyc and b/__pycache__/utils.cpython-310.pyc differ
 
app.py CHANGED
@@ -11,6 +11,7 @@ from PIL import Image, PngImagePlugin
11
  from datetime import datetime
12
  from diffusers.models import AutoencoderKL
13
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
 
14
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
@@ -38,10 +39,10 @@ torch.backends.cudnn.benchmark = False
38
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39
 
40
 
41
- def load_pipeline(model_name):
42
  vae = AutoencoderKL.from_pretrained(
43
- "madebyollin/sdxl-vae-fp16-fix",
44
- torch_dtype=torch.float16,
45
  )
46
  pipeline = (
47
  StableDiffusionXLPipeline.from_single_file
@@ -52,7 +53,7 @@ def load_pipeline(model_name):
52
  pipe = pipeline(
53
  model_name,
54
  vae=vae,
55
- torch_dtype=torch.float16,
56
  custom_pipeline="lpw_stable_diffusion_xl",
57
  use_safetensors=True,
58
  add_watermarker=False,
@@ -195,7 +196,9 @@ if torch.cuda.is_available():
195
  pipe = load_pipeline(MODEL)
196
  logger.info("Loaded on Device!")
197
  else:
198
- pipe = None
 
 
199
 
200
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.style_list}
201
  quality_prompt = {
@@ -204,6 +207,14 @@ quality_prompt = {
204
 
205
  wildcard_files = utils.load_wildcard_files("wildcard")
206
 
 
 
 
 
 
 
 
 
207
  with gr.Blocks(css="style.css", theme="NoCrypt/miku@1.2.1") as demo:
208
  title = gr.HTML(
209
  f"""<h1><span>{DESCRIPTION}</span></h1>""",
@@ -321,6 +332,305 @@ with gr.Blocks(css="style.css", theme="NoCrypt/miku@1.2.1") as demo:
321
  step=1,
322
  value=28,
323
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  with gr.Column(scale=3):
325
  with gr.Blocks():
326
  run_button = gr.Button("Generate", variant="primary")
 
11
  from datetime import datetime
12
  from diffusers.models import AutoencoderKL
13
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
14
+ import tags
15
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
 
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
 
41
 
42
+ def load_pipeline(model_name, vae_model="madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16):
43
  vae = AutoencoderKL.from_pretrained(
44
+ vae_model,
45
+ torch_dtype=torch_dtype,
46
  )
47
  pipeline = (
48
  StableDiffusionXLPipeline.from_single_file
 
53
  pipe = pipeline(
54
  model_name,
55
  vae=vae,
56
+ torch_dtype=torch_dtype,
57
  custom_pipeline="lpw_stable_diffusion_xl",
58
  use_safetensors=True,
59
  add_watermarker=False,
 
196
  pipe = load_pipeline(MODEL)
197
  logger.info("Loaded on Device!")
198
  else:
199
+ logger.info("CPU MODE")
200
+ pipe = load_pipeline(MODEL, vae_model="stabilityai/sdxl-vae", torch_dtype=torch.float32)
201
+ logger.info("Loaded on Device!")
202
 
203
  styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in config.style_list}
204
  quality_prompt = {
 
207
 
208
  wildcard_files = utils.load_wildcard_files("wildcard")
209
 
210
+ COPY_ACTION_JS = """\
211
+ (inputs, _outputs) => {
212
+ // inputs is the string value of the input_text
213
+ if (inputs.trim() !== "") {
214
+ navigator.clipboard.writeText(inputs);
215
+ }
216
+ }"""
217
+
218
  with gr.Blocks(css="style.css", theme="NoCrypt/miku@1.2.1") as demo:
219
  title = gr.HTML(
220
  f"""<h1><span>{DESCRIPTION}</span></h1>""",
 
332
  step=1,
333
  value=28,
334
  )
335
+
336
+ # danbooru-tags-upsampler
337
+ with gr.Tab("tags"):
338
+ with gr.Row():
339
+ with gr.Column():
340
+
341
+ # with gr.Group(
342
+ # visible=False,
343
+ # ):
344
+ # model_backend_radio = gr.Radio(
345
+ # label="Model backend",
346
+ # choices=list(MODEL_BACKEND_MAP.keys()),
347
+ # value="Default",
348
+ # interactive=True,
349
+ # )
350
+
351
+ with gr.Group():
352
+ rating_dropdown = gr.Dropdown(
353
+ label="Rating",
354
+ choices=[
355
+ "general",
356
+ "sensitive",
357
+ "questionable",
358
+ "explicit",
359
+ ],
360
+ value="general",
361
+ )
362
+
363
+ with gr.Group():
364
+ copyright_tags_mode_dropdown = gr.Dropdown(
365
+ label="Copyright tags mode",
366
+ choices=[
367
+ "None",
368
+ "Original",
369
+ # "Auto", # TODO: implement these modes
370
+ # "Random",
371
+ "Custom",
372
+ ],
373
+ value="None",
374
+ interactive=True,
375
+ )
376
+ copyright_tags_dropdown = gr.Dropdown(
377
+ label="Copyright tags",
378
+ choices=tags.get_copyright_tags_list(), # type: ignore
379
+ value=[],
380
+ multiselect=True,
381
+ visible=False,
382
+ )
383
+
384
+ def on_change_copyright_tags_dropdouwn(mode: str):
385
+ kwargs: dict = {"visible": mode == "Custom"}
386
+ if mode == "Original":
387
+ kwargs["value"] = ["original"]
388
+ elif mode == "None":
389
+ kwargs["value"] = []
390
+
391
+ return gr.update(**kwargs)
392
+
393
+ with gr.Group():
394
+ character_tags_mode_dropdown = gr.Dropdown(
395
+ label="Character tags mode",
396
+ choices=[
397
+ "None",
398
+ # "Auto", # TODO: implement these modes
399
+ # "Random",
400
+ "Custom",
401
+ ],
402
+ value="None",
403
+ interactive=True,
404
+ )
405
+ character_tags_dropdown = gr.Dropdown(
406
+ label="Character tags",
407
+ choices=tags.get_character_tags_list(), # type: ignore
408
+ value=[],
409
+ multiselect=True,
410
+ visible=False,
411
+ )
412
+
413
+ def on_change_character_tags_dropdouwn(mode: str):
414
+ kwargs: dict = {"visible": mode == "Custom"}
415
+ if mode == "None":
416
+ kwargs["value"] = []
417
+
418
+ return gr.update(**kwargs)
419
+
420
+ with gr.Group():
421
+ general_tags_textbox = gr.Textbox(
422
+ label="General tags (the condition to generate tags)",
423
+ value="",
424
+ placeholder="1girl, ...",
425
+ lines=4,
426
+ )
427
+
428
+ ban_tags_textbox = gr.Textbox(
429
+ label="Ban tags (tags in this field never appear in generation)",
430
+ value="",
431
+ placeholder="official alternate cosutme, english text,...",
432
+ lines=2,
433
+ )
434
+
435
+ generate_btn = gr.Button("Generate", variant="primary")
436
+
437
+ with gr.Accordion(label="Generation config (advanced)", open=False):
438
+ with gr.Group():
439
+ do_cfg_check = gr.Checkbox(
440
+ label="Do CFG (Classifier Free Guidance)",
441
+ value=False,
442
+ )
443
+ cfg_scale_slider = gr.Slider(
444
+ label="CFG scale",
445
+ maximum=3.0,
446
+ minimum=0.1,
447
+ step=0.1,
448
+ value=1.5,
449
+ visible=False,
450
+ )
451
+ negative_tags_textbox = gr.Textbox(
452
+ label="Negative prompt",
453
+ placeholder="simple background, ...",
454
+ value="",
455
+ lines=2,
456
+ visible=False,
457
+ )
458
+
459
+ def on_change_do_cfg_check(do_cfg: bool):
460
+ kwargs: dict = {"visible": do_cfg}
461
+ return gr.update(**kwargs), gr.update(**kwargs)
462
+
463
+ do_cfg_check.change(
464
+ on_change_do_cfg_check,
465
+ inputs=[do_cfg_check],
466
+ outputs=[cfg_scale_slider, negative_tags_textbox],
467
+ )
468
+
469
+ with gr.Group():
470
+ total_token_length_radio = gr.Radio(
471
+ label="Total token length",
472
+ choices=list(tags.get_length_tags().keys()),
473
+ value="long",
474
+ )
475
+
476
+ with gr.Group():
477
+ max_new_tokens_slider = gr.Slider(
478
+ label="Max new tokens",
479
+ maximum=256,
480
+ minimum=1,
481
+ step=1,
482
+ value=128,
483
+ )
484
+ min_new_tokens_slider = gr.Slider(
485
+ label="Min new tokens",
486
+ maximum=255,
487
+ minimum=0,
488
+ step=1,
489
+ value=0,
490
+ )
491
+ temperature_slider = gr.Slider(
492
+ label="Temperature (larger is more random)",
493
+ maximum=1.0,
494
+ minimum=0.0,
495
+ step=0.1,
496
+ value=1.0,
497
+ )
498
+ top_p_slider = gr.Slider(
499
+ label="Top p (larger is more random)",
500
+ maximum=1.0,
501
+ minimum=0.0,
502
+ step=0.1,
503
+ value=1.0,
504
+ )
505
+ top_k_slider = gr.Slider(
506
+ label="Top k (larger is more random)",
507
+ maximum=500,
508
+ minimum=1,
509
+ step=1,
510
+ value=100,
511
+ )
512
+ num_beams_slider = gr.Slider(
513
+ label="Number of beams (smaller is more random)",
514
+ maximum=10,
515
+ minimum=1,
516
+ step=1,
517
+ value=1,
518
+ )
519
+
520
+ with gr.Column():
521
+ with gr.Group():
522
+ output_tags_natural = gr.Textbox(
523
+ label="Generation result",
524
+ # placeholder="tags will be here",
525
+ interactive=False,
526
+ )
527
+ output_tags_natural_copy_btn = gr.Button("Copy", visible=False)
528
+ output_tags_natural_copy_btn.click(
529
+ fn=tags.copy_text,
530
+ inputs=[output_tags_natural],
531
+ js=COPY_ACTION_JS,
532
+ )
533
+
534
+ with gr.Group():
535
+ output_tags_general_only = gr.Textbox(
536
+ label="General tags only (sorted)",
537
+ interactive=False,
538
+ )
539
+ output_tags_general_only_copy_btn = gr.Button("Copy", visible=False)
540
+ output_tags_general_only_copy_btn.click(
541
+ fn=tags.copy_text,
542
+ inputs=[output_tags_general_only],
543
+ js=COPY_ACTION_JS,
544
+ )
545
+
546
+ with gr.Group():
547
+ output_tags_animagine = gr.Textbox(
548
+ label="Output tags (AnimagineXL v3 style order)",
549
+ # placeholder="tags will be here in Animagine v3 style order",
550
+ interactive=False,
551
+ )
552
+ output_tags_animagine_copy_btn = gr.Button("Copy", visible=False)
553
+ output_tags_animagine_copy_btn.click(
554
+ fn=tags.copy_text,
555
+ inputs=[output_tags_animagine],
556
+ js=COPY_ACTION_JS,
557
+ )
558
+
559
+ with gr.Accordion(label="Metadata", open=False):
560
+ _model_backend_md = gr.Markdown(
561
+ f"Model backend: {tags.get_model_backend()}",
562
+ )
563
+ input_prompt_raw = gr.Textbox(
564
+ label="Input prompt (raw)",
565
+ interactive=False,
566
+ lines=4,
567
+ )
568
+
569
+ output_tags_raw = gr.Textbox(
570
+ label="Output tags (raw)",
571
+ interactive=False,
572
+ lines=4,
573
+ )
574
+
575
+ elapsed_time_md = gr.Markdown(value="Waiting to generate...")
576
+
577
+ copyright_tags_mode_dropdown.change(
578
+ on_change_copyright_tags_dropdouwn,
579
+ inputs=[copyright_tags_mode_dropdown],
580
+ outputs=[copyright_tags_dropdown],
581
+ )
582
+ character_tags_mode_dropdown.change(
583
+ on_change_character_tags_dropdouwn,
584
+ inputs=[character_tags_mode_dropdown],
585
+ outputs=[character_tags_dropdown],
586
+ )
587
+
588
+ generate_btn.click(
589
+ tags.handle_inputs,
590
+ inputs=[
591
+ rating_dropdown,
592
+ copyright_tags_dropdown,
593
+ character_tags_dropdown,
594
+ general_tags_textbox,
595
+ ban_tags_textbox,
596
+ do_cfg_check,
597
+ cfg_scale_slider,
598
+ negative_tags_textbox,
599
+ total_token_length_radio,
600
+ max_new_tokens_slider,
601
+ min_new_tokens_slider,
602
+ temperature_slider,
603
+ top_p_slider,
604
+ top_k_slider,
605
+ num_beams_slider,
606
+ # model_backend_radio,
607
+ ],
608
+ outputs=[
609
+ output_tags_natural,
610
+ output_tags_general_only,
611
+ output_tags_animagine,
612
+ input_prompt_raw,
613
+ output_tags_raw,
614
+ elapsed_time_md,
615
+ output_tags_natural_copy_btn,
616
+ output_tags_general_only_copy_btn,
617
+ output_tags_animagine_copy_btn,
618
+ ],
619
+ )
620
+
621
+ gr.Examples(
622
+ examples=[
623
+ ["1girl, solo, from side", ""],
624
+ ["1girl, solo, abstract, from above", ""],
625
+ ["2girls, yuri", "1boy"],
626
+ ["no humans, scenery, summer, day", ""],
627
+ ],
628
+ inputs=[
629
+ general_tags_textbox,
630
+ ban_tags_textbox,
631
+ ],
632
+ )
633
+
634
  with gr.Column(scale=3):
635
  with gr.Blocks():
636
  run_button = gr.Button("Generate", variant="primary")
requirements.txt CHANGED
@@ -8,4 +8,6 @@ torch==2.0.1
8
  transformers==4.38.1
9
  omegaconf==2.3.0
10
  timm==0.9.10
11
- optimum[onnxruntime]@git+https://github.com/huggingface/optimum.git
 
 
 
8
  transformers==4.38.1
9
  omegaconf==2.3.0
10
  timm==0.9.10
11
+ #optimum[onnxruntime]@git+https://github.com/huggingface/optimum.git
12
+ transformers==4.38.0
13
+ optimum[onnxruntime]==1.17.1
tags.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import time
3
+ import os
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
+ from optimum.onnxruntime import ORTModelForCausalLM
9
+
10
+
11
+ import gradio as gr
12
+
13
+ MODEL_NAME = (
14
+ os.environ.get("MODEL_NAME")
15
+ if os.environ.get("MODEL_NAME") is not None
16
+ else "p1atdev/dart-v1-sft"
17
+ )
18
+ HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN")
19
+ MODEL_BACKEND = (
20
+ os.environ.get("MODEL_BACKEND")
21
+ if os.environ.get("MODEL_BACKEND") is not None
22
+ else "ONNX (quantized)"
23
+ )
24
+
25
+ assert isinstance(MODEL_NAME, str)
26
+ assert isinstance(MODEL_BACKEND, str)
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(
29
+ MODEL_NAME,
30
+ trust_remote_code=True,
31
+ token=HF_READ_TOKEN,
32
+ )
33
+ model = {
34
+ "default": AutoModelForCausalLM.from_pretrained(
35
+ MODEL_NAME,
36
+ token=HF_READ_TOKEN,
37
+ ),
38
+ "ort": ORTModelForCausalLM.from_pretrained(
39
+ MODEL_NAME,
40
+ ),
41
+ "ort_qantized": ORTModelForCausalLM.from_pretrained(
42
+ MODEL_NAME,
43
+ file_name="model_quantized.onnx",
44
+ ),
45
+ }
46
+
47
+ MODEL_BACKEND_MAP = {
48
+ "Default": "default",
49
+ "ONNX (normal)": "ort",
50
+ "ONNX (quantized)": "ort_qantized",
51
+ }
52
+
53
+ try:
54
+ model["default"].to("cuda")
55
+ except:
56
+ print("No GPU")
57
+
58
+ try:
59
+ model["default"] = torch.compile(model["default"])
60
+ except:
61
+ print("torch.compile is not supported")
62
+
63
+ BOS = "<|bos|>"
64
+ EOS = "<|eos|>"
65
+ RATING_BOS = "<rating>"
66
+ RATING_EOS = "</rating>"
67
+ COPYRIGHT_BOS = "<copyright>"
68
+ COPYRIGHT_EOS = "</copyright>"
69
+ CHARACTER_BOS = "<character>"
70
+ CHARACTER_EOS = "</character>"
71
+ GENERAL_BOS = "<general>"
72
+ GENERAL_EOS = "</general>"
73
+
74
+ INPUT_END = "<|input_end|>"
75
+
76
+ LENGTH_VERY_SHORT = "<|very_short|>"
77
+ LENGTH_SHORT = "<|short|>"
78
+ LENGTH_LONG = "<|long|>"
79
+ LENGTH_VERY_LONG = "<|very_long|>"
80
+
81
+ RATING_BOS_ID = tokenizer.convert_tokens_to_ids(RATING_BOS)
82
+ RATING_EOS_ID = tokenizer.convert_tokens_to_ids(RATING_EOS)
83
+ COPYRIGHT_BOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_BOS)
84
+ COPYRIGHT_EOS_ID = tokenizer.convert_tokens_to_ids(COPYRIGHT_EOS)
85
+ CHARACTER_BOS_ID = tokenizer.convert_tokens_to_ids(CHARACTER_BOS)
86
+ CHARACTER_EOS_ID = tokenizer.convert_tokens_to_ids(CHARACTER_EOS)
87
+ GENERAL_BOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_BOS)
88
+ GENERAL_EOS_ID = tokenizer.convert_tokens_to_ids(GENERAL_EOS)
89
+
90
+ assert isinstance(RATING_BOS_ID, int)
91
+ assert isinstance(RATING_EOS_ID, int)
92
+ assert isinstance(COPYRIGHT_BOS_ID, int)
93
+ assert isinstance(COPYRIGHT_EOS_ID, int)
94
+ assert isinstance(CHARACTER_BOS_ID, int)
95
+ assert isinstance(CHARACTER_EOS_ID, int)
96
+ assert isinstance(GENERAL_BOS_ID, int)
97
+ assert isinstance(GENERAL_EOS_ID, int)
98
+
99
+ SPECIAL_TAGS = [
100
+ BOS,
101
+ EOS,
102
+ RATING_BOS,
103
+ RATING_EOS,
104
+ COPYRIGHT_BOS,
105
+ COPYRIGHT_EOS,
106
+ CHARACTER_BOS,
107
+ CHARACTER_EOS,
108
+ GENERAL_BOS,
109
+ GENERAL_EOS,
110
+ INPUT_END,
111
+ LENGTH_VERY_SHORT,
112
+ LENGTH_SHORT,
113
+ LENGTH_LONG,
114
+ LENGTH_VERY_LONG,
115
+ ]
116
+
117
+ SPECIAL_TAG_IDS = tokenizer.convert_tokens_to_ids(SPECIAL_TAGS)
118
+ assert isinstance(SPECIAL_TAG_IDS, list)
119
+ assert all([token_id != tokenizer.unk_token_id for token_id in SPECIAL_TAG_IDS])
120
+
121
+ RATING_TAGS = {
122
+ "sfw": "rating:sfw",
123
+ "nsfw": "rating:nsfw",
124
+ "general": "rating:general",
125
+ "sensitive": "rating:sensitive",
126
+ "questionable": "rating:questionable",
127
+ "explicit": "rating:explicit",
128
+ }
129
+ RATING_TAG_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in RATING_TAGS.items()}
130
+
131
+ LENGTH_TAGS = {
132
+ "very short": LENGTH_VERY_SHORT,
133
+ "short": LENGTH_SHORT,
134
+ "long": LENGTH_LONG,
135
+ "very long": LENGTH_VERY_LONG,
136
+ }
137
+
138
+
139
+ def load_tags(path: str | Path):
140
+ if isinstance(path, str):
141
+ path = Path(path)
142
+
143
+ with open(path, "r", encoding="utf-8") as file:
144
+ lines = [line.strip() for line in file.readlines() if line.strip() != ""]
145
+
146
+ return lines
147
+
148
+
149
+ COPYRIGHT_TAGS_LIST: list[str] = load_tags("./tags/copyright.txt")
150
+ CHARACTER_TAGS_LIST: list[str] = load_tags("./tags/character.txt")
151
+ PEOPLE_TAGS_LIST: list[str] = load_tags("./tags/people.txt")
152
+
153
+ PEOPLE_TAG_IDS_LIST = tokenizer.convert_tokens_to_ids(PEOPLE_TAGS_LIST)
154
+
155
+ assert isinstance(PEOPLE_TAG_IDS_LIST, list)
156
+
157
+
158
+ @torch.no_grad()
159
+ def generate(
160
+ input_text: str,
161
+ model_backend: str,
162
+ max_new_tokens: int = 128,
163
+ min_new_tokens: int = 0,
164
+ do_sample: bool = True,
165
+ temperature: float = 1.0,
166
+ top_p: float = 1,
167
+ top_k: int = 20,
168
+ num_beams: int = 1,
169
+ bad_words_ids: list[int] | None = None,
170
+ cfg_scale: float = 1.5,
171
+ negative_input_text: str | None = None,
172
+ ) -> list[int]:
173
+ inputs = tokenizer(
174
+ input_text,
175
+ return_tensors="pt",
176
+ ).input_ids.to(model[MODEL_BACKEND_MAP[model_backend]].device)
177
+ negative_inputs = (
178
+ tokenizer(
179
+ negative_input_text,
180
+ return_tensors="pt",
181
+ ).input_ids.to(model[MODEL_BACKEND_MAP[model_backend]].device)
182
+ if negative_input_text is not None
183
+ else None
184
+ )
185
+
186
+ generated = model[MODEL_BACKEND_MAP[model_backend]].generate(
187
+ inputs,
188
+ max_new_tokens=max_new_tokens,
189
+ min_new_tokens=min_new_tokens,
190
+ do_sample=do_sample,
191
+ temperature=temperature,
192
+ top_p=top_p,
193
+ top_k=top_k,
194
+ num_beams=num_beams,
195
+ bad_words_ids=(
196
+ [[token] for token in bad_words_ids] if bad_words_ids is not None else None
197
+ ),
198
+ negative_prompt_ids=negative_inputs,
199
+ guidance_scale=cfg_scale,
200
+ no_repeat_ngram_size=1,
201
+ )[0]
202
+
203
+ return generated.tolist()
204
+
205
+
206
+ def decode_normal(token_ids: list[int], skip_special_tokens: bool = True):
207
+ return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
208
+
209
+
210
+ def decode_general_only(token_ids: list[int]):
211
+ token_ids = token_ids[token_ids.index(GENERAL_BOS_ID) :]
212
+ decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
213
+ tags = [tag for tag in decoded.split(", ")]
214
+ tags = sorted(tags)
215
+ return ", ".join(tags)
216
+
217
+
218
+ def split_people_tokens_part(token_ids: list[int]):
219
+ people_tokens = []
220
+ other_tokens = []
221
+
222
+ for token in token_ids:
223
+ if token in PEOPLE_TAG_IDS_LIST:
224
+ people_tokens.append(token)
225
+ else:
226
+ other_tokens.append(token)
227
+
228
+ return people_tokens, other_tokens
229
+
230
+
231
+ def decode_animagine(token_ids: list[int]):
232
+ def get_part(eos_token_id: int, remains_part: list[int]):
233
+ part = []
234
+ for i, token_id in enumerate(remains_part):
235
+ if token_id == eos_token_id:
236
+ return part, remains_part[i:]
237
+
238
+ part.append(token_id)
239
+
240
+ raise Exception("The provided EOS token was not found in the token_ids.")
241
+
242
+ # get each part
243
+ rating_part, remains = get_part(RATING_EOS_ID, token_ids)
244
+ copyright_part, remains = get_part(COPYRIGHT_EOS_ID, remains)
245
+ character_part, remains = get_part(CHARACTER_EOS_ID, remains)
246
+ general_part, _ = get_part(GENERAL_EOS_ID, remains)
247
+
248
+ # separete people tags (1girl, 1boy, no humans...)
249
+ people_part, other_general_part = split_people_tokens_part(general_part)
250
+
251
+ # remove "rating:sfw"
252
+ rating_part = [token for token in rating_part if token != RATING_TAG_IDS["sfw"]]
253
+
254
+ # AnimagineXL v3 style order
255
+ rearranged_tokens = (
256
+ people_part + character_part + copyright_part + other_general_part + rating_part
257
+ )
258
+ rearranged_tokens = [
259
+ token for token in rearranged_tokens if token not in SPECIAL_TAG_IDS
260
+ ]
261
+
262
+ decoded = tokenizer.decode(rearranged_tokens, skip_special_tokens=True)
263
+
264
+ # fix "nsfw" tag
265
+ decoded = decoded.replace("rating:nsfw", "nsfw")
266
+
267
+ return decoded
268
+
269
+
270
+ def prepare_rating_tags(rating: str):
271
+ tag = RATING_TAGS[rating]
272
+ if tag in [RATING_TAGS["general"], RATING_TAGS["sensitive"]]:
273
+ parent = RATING_TAGS["sfw"]
274
+ else:
275
+ parent = RATING_TAGS["nsfw"]
276
+
277
+ return f"{parent}, {tag}"
278
+
279
+
280
+ def handle_inputs(
281
+ rating_tags: str,
282
+ copyright_tags_list: list[str],
283
+ character_tags_list: list[str],
284
+ general_tags: str,
285
+ ban_tags: str,
286
+ do_cfg: bool = False,
287
+ cfg_scale: float = 1.5,
288
+ negative_tags: str = "",
289
+ total_token_length: str = "long",
290
+ max_new_tokens: int = 128,
291
+ min_new_tokens: int = 0,
292
+ temperature: float = 1.0,
293
+ top_p: float = 1.0,
294
+ top_k: int = 20,
295
+ num_beams: int = 1,
296
+ # model_backend: str = "Default",
297
+ ):
298
+ """
299
+ Returns:
300
+ [
301
+ output_tags_natural,
302
+ output_tags_general_only,
303
+ output_tags_animagine,
304
+ input_prompt_raw,
305
+ output_tags_raw,
306
+ elapsed_time,
307
+ output_tags_natural_copy_btn,
308
+ output_tags_general_only_copy_btn,
309
+ output_tags_animagine_copy_btn
310
+ ]
311
+ """
312
+
313
+ start_time = time.time()
314
+
315
+ copyright_tags = ", ".join(copyright_tags_list)
316
+ character_tags = ", ".join(character_tags_list)
317
+
318
+ token_length_tag = LENGTH_TAGS[total_token_length]
319
+
320
+ prompt: str = tokenizer.apply_chat_template(
321
+ { # type: ignore
322
+ "rating": prepare_rating_tags(rating_tags),
323
+ "copyright": copyright_tags,
324
+ "character": character_tags,
325
+ "general": general_tags,
326
+ "length": token_length_tag,
327
+ },
328
+ tokenize=False,
329
+ )
330
+
331
+ negative_prompt: str = tokenizer.apply_chat_template(
332
+ { # type: ignore
333
+ "rating": prepare_rating_tags(rating_tags),
334
+ "copyright": "",
335
+ "character": "",
336
+ "general": negative_tags,
337
+ "length": token_length_tag,
338
+ },
339
+ tokenize=False,
340
+ )
341
+
342
+ bad_words_ids = tokenizer.encode_plus(
343
+ ban_tags if negative_tags.strip() == "" else ban_tags + ", " + negative_tags
344
+ ).input_ids
345
+
346
+ generated_ids = generate(
347
+ prompt,
348
+ model_backend=MODEL_BACKEND,
349
+ max_new_tokens=max_new_tokens,
350
+ min_new_tokens=min_new_tokens,
351
+ do_sample=True,
352
+ temperature=temperature,
353
+ top_p=top_p,
354
+ top_k=top_k,
355
+ num_beams=num_beams,
356
+ bad_words_ids=bad_words_ids if len(bad_words_ids) > 0 else None,
357
+ cfg_scale=cfg_scale,
358
+ negative_input_text=negative_prompt if do_cfg else None,
359
+ )
360
+
361
+ decoded_normal = decode_normal(generated_ids, skip_special_tokens=True)
362
+ decoded_general_only = decode_general_only(generated_ids)
363
+ decoded_animagine = decode_animagine(generated_ids)
364
+ decoded_raw = decode_normal(generated_ids, skip_special_tokens=False)
365
+
366
+ end_time = time.time()
367
+ elapsed_time = f"Elapsed: {(end_time - start_time) * 1000:.2f} ms"
368
+
369
+ # update visibility of buttons
370
+ set_visible = gr.update(visible=True)
371
+
372
+ return [
373
+ decoded_normal,
374
+ decoded_general_only,
375
+ decoded_animagine,
376
+ prompt,
377
+ decoded_raw,
378
+ elapsed_time,
379
+ set_visible,
380
+ set_visible,
381
+ set_visible,
382
+ ]
383
+
384
+
385
+ # ref: https://qiita.com/tregu148/items/fccccbbc47d966dd2fc2
386
+ def copy_text(_text: None):
387
+ gr.Info("Copied!")
388
+
389
+
390
+ def get_model_backend():
391
+ return MODEL_BACKEND
392
+
393
+ def get_length_tags():
394
+ return LENGTH_TAGS
395
+
396
+ def get_copyright_tags_list():
397
+ return COPYRIGHT_TAGS_LIST
398
+
399
+ def get_character_tags_list():
400
+ return CHARACTER_TAGS_LIST