AlekseyCalvin commited on
Commit
ad89ac3
·
verified ·
1 Parent(s): 07ec991

Delete modutils.py

Browse files
Files changed (1) hide show
  1. modutils.py +0 -1290
modutils.py DELETED
@@ -1,1290 +0,0 @@
1
- import spaces
2
- import json
3
- import gradio as gr
4
- from huggingface_hub import HfApi
5
- import os
6
- from pathlib import Path
7
- from PIL import Image
8
-
9
-
10
- from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
11
- HF_MODEL_USER_EX, HF_MODEL_USER_LIKES, DIFFUSERS_FORMAT_LORAS,
12
- directory_loras, hf_read_token, HF_TOKEN, CIVITAI_API_KEY)
13
-
14
-
15
- MODEL_TYPE_DICT = {
16
- "diffusers:StableDiffusionPipeline": "SD 1.5",
17
- "diffusers:StableDiffusionXLPipeline": "SDXL",
18
- "diffusers:FluxPipeline": "FLUX",
19
- }
20
-
21
-
22
- def get_user_agent():
23
- return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
24
-
25
-
26
- def to_list(s):
27
- return [x.strip() for x in s.split(",") if not s == ""]
28
-
29
-
30
- def list_uniq(l):
31
- return sorted(set(l), key=l.index)
32
-
33
-
34
- def list_sub(a, b):
35
- return [e for e in a if e not in b]
36
-
37
-
38
- def is_repo_name(s):
39
- import re
40
- return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
41
-
42
-
43
- from translatepy import Translator
44
- translator = Translator()
45
- def translate_to_en(input: str):
46
- try:
47
- output = str(translator.translate(input, 'English'))
48
- except Exception as e:
49
- output = input
50
- print(e)
51
- return output
52
-
53
-
54
- def get_local_model_list(dir_path):
55
- model_list = []
56
- valid_extensions = ('.ckpt', '.pt', '.pth', '.safetensors', '.bin')
57
- for file in Path(dir_path).glob("*"):
58
- if file.suffix in valid_extensions:
59
- file_path = str(Path(f"{dir_path}/{file.name}"))
60
- model_list.append(file_path)
61
- return model_list
62
-
63
-
64
- def download_things(directory, url, hf_token="", civitai_api_key=""):
65
- url = url.strip()
66
- if "drive.google.com" in url:
67
- original_dir = os.getcwd()
68
- os.chdir(directory)
69
- os.system(f"gdown --fuzzy {url}")
70
- os.chdir(original_dir)
71
- elif "huggingface.co" in url:
72
- url = url.replace("?download=true", "")
73
- # url = urllib.parse.quote(url, safe=':/') # fix encoding
74
- if "/blob/" in url:
75
- url = url.replace("/blob/", "/resolve/")
76
- user_header = f'"Authorization: Bearer {hf_token}"'
77
- if hf_token:
78
- os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
79
- else:
80
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
81
- elif "civitai.com" in url:
82
- if "?" in url:
83
- url = url.split("?")[0]
84
- if civitai_api_key:
85
- url = url + f"?token={civitai_api_key}"
86
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
87
- else:
88
- print("\033[91mYou need an API key to download Civitai models.\033[0m")
89
- else:
90
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
91
-
92
-
93
- def escape_lora_basename(basename: str):
94
- return basename.replace(".", "_").replace(" ", "_").replace(",", "")
95
-
96
-
97
- def to_lora_key(path: str):
98
- return escape_lora_basename(Path(path).stem)
99
-
100
-
101
- def to_lora_path(key: str):
102
- if Path(key).is_file(): return key
103
- path = Path(f"{directory_loras}/{escape_lora_basename(key)}.safetensors")
104
- return str(path)
105
-
106
-
107
- def safe_float(input):
108
- output = 1.0
109
- try:
110
- output = float(input)
111
- except Exception:
112
- output = 1.0
113
- return output
114
-
115
-
116
- def save_images(images: list[Image.Image], metadatas: list[str]):
117
- from PIL import PngImagePlugin
118
- import uuid
119
- try:
120
- output_images = []
121
- for image, metadata in zip(images, metadatas):
122
- info = PngImagePlugin.PngInfo()
123
- info.add_text("parameters", metadata)
124
- savefile = f"{str(uuid.uuid4())}.png"
125
- image.save(savefile, "PNG", pnginfo=info)
126
- output_images.append(str(Path(savefile).resolve()))
127
- return output_images
128
- except Exception as e:
129
- print(f"Failed to save image file: {e}")
130
- raise Exception(f"Failed to save image file:") from e
131
-
132
-
133
- def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
134
- from datetime import datetime, timezone, timedelta
135
- progress(0, desc="Updating gallery...")
136
- dt_now = datetime.now(timezone(timedelta(hours=9)))
137
- basename = dt_now.strftime('%Y%m%d_%H%M%S_')
138
- i = 1
139
- if not images: return images, gr.update(visible=False)
140
- output_images = []
141
- output_paths = []
142
- for image in images:
143
- filename = basename + str(i) + ".png"
144
- i += 1
145
- oldpath = Path(image[0])
146
- newpath = oldpath
147
- try:
148
- if oldpath.exists():
149
- newpath = oldpath.resolve().rename(Path(filename).resolve())
150
- except Exception as e:
151
- print(e)
152
- finally:
153
- output_paths.append(str(newpath))
154
- output_images.append((str(newpath), str(filename)))
155
- progress(1, desc="Gallery updated.")
156
- return gr.update(value=output_images), gr.update(value=output_paths, visible=True)
157
-
158
-
159
- def download_private_repo(repo_id, dir_path, is_replace):
160
- from huggingface_hub import snapshot_download
161
- if not hf_read_token: return
162
- try:
163
- snapshot_download(repo_id=repo_id, local_dir=dir_path, allow_patterns=['*.ckpt', '*.pt', '*.pth', '*.safetensors', '*.bin'], use_auth_token=hf_read_token)
164
- except Exception as e:
165
- print(f"Error: Failed to download {repo_id}.")
166
- print(e)
167
- return
168
- if is_replace:
169
- for file in Path(dir_path).glob("*"):
170
- if file.exists() and "." in file.stem or " " in file.stem and file.suffix in ['.ckpt', '.pt', '.pth', '.safetensors', '.bin']:
171
- newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}')
172
- file.resolve().rename(newpath.resolve())
173
-
174
-
175
- private_model_path_repo_dict = {} # {"local filepath": "huggingface repo_id", ...}
176
-
177
-
178
- def get_private_model_list(repo_id, dir_path):
179
- global private_model_path_repo_dict
180
- api = HfApi()
181
- if not hf_read_token: return []
182
- try:
183
- files = api.list_repo_files(repo_id, token=hf_read_token)
184
- except Exception as e:
185
- print(f"Error: Failed to list {repo_id}.")
186
- print(e)
187
- return []
188
- model_list = []
189
- for file in files:
190
- path = Path(f"{dir_path}/{file}")
191
- if path.suffix in ['.ckpt', '.pt', '.pth', '.safetensors', '.bin']:
192
- model_list.append(str(path))
193
- for model in model_list:
194
- private_model_path_repo_dict[model] = repo_id
195
- return model_list
196
-
197
-
198
- def download_private_file(repo_id, path, is_replace):
199
- from huggingface_hub import hf_hub_download
200
- file = Path(path)
201
- newpath = Path(f'{file.parent.name}/{escape_lora_basename(file.stem)}{file.suffix}') if is_replace else file
202
- if not hf_read_token or newpath.exists(): return
203
- filename = file.name
204
- dirname = file.parent.name
205
- try:
206
- hf_hub_download(repo_id=repo_id, filename=filename, local_dir=dirname, use_auth_token=hf_read_token)
207
- except Exception as e:
208
- print(f"Error: Failed to download {filename}.")
209
- print(e)
210
- return
211
- if is_replace:
212
- file.resolve().rename(newpath.resolve())
213
-
214
-
215
- def download_private_file_from_somewhere(path, is_replace):
216
- if not path in private_model_path_repo_dict.keys(): return
217
- repo_id = private_model_path_repo_dict.get(path, None)
218
- download_private_file(repo_id, path, is_replace)
219
-
220
-
221
- model_id_list = []
222
- def get_model_id_list():
223
- global model_id_list
224
- if len(model_id_list) != 0: return model_id_list
225
- api = HfApi()
226
- model_ids = []
227
- try:
228
- models_likes = []
229
- for author in HF_MODEL_USER_LIKES:
230
- models_likes.extend(api.list_models(author=author, task="text-to-image", cardData=True, sort="likes"))
231
- models_ex = []
232
- for author in HF_MODEL_USER_EX:
233
- models_ex = api.list_models(author=author, task="text-to-image", cardData=True, sort="last_modified")
234
- except Exception as e:
235
- print(f"Error: Failed to list {author}'s models.")
236
- print(e)
237
- return model_ids
238
- for model in models_likes:
239
- model_ids.append(model.id) if not model.private else ""
240
- anime_models = []
241
- real_models = []
242
- anime_models_flux = []
243
- real_models_flux = []
244
- for model in models_ex:
245
- if not model.private and not model.gated:
246
- if "diffusers:FluxPipeline" in model.tags: anime_models_flux.append(model.id) if "anime" in model.tags else real_models_flux.append(model.id)
247
- else: anime_models.append(model.id) if "anime" in model.tags else real_models.append(model.id)
248
- model_ids.extend(anime_models)
249
- model_ids.extend(real_models)
250
- model_ids.extend(anime_models_flux)
251
- model_ids.extend(real_models_flux)
252
- model_id_list = model_ids.copy()
253
- return model_ids
254
-
255
-
256
- model_id_list = get_model_id_list()
257
-
258
-
259
- def get_t2i_model_info(repo_id: str):
260
- api = HfApi(token=HF_TOKEN)
261
- try:
262
- if not is_repo_name(repo_id): return ""
263
- model = api.model_info(repo_id=repo_id, timeout=5.0)
264
- except Exception as e:
265
- print(f"Error: Failed to get {repo_id}'s info.")
266
- print(e)
267
- return ""
268
- if model.private or model.gated: return ""
269
- tags = model.tags
270
- info = []
271
- url = f"https://huggingface.co/{repo_id}/"
272
- if not 'diffusers' in tags: return ""
273
- for k, v in MODEL_TYPE_DICT.items():
274
- if k in tags: info.append(v)
275
- if model.card_data and model.card_data.tags:
276
- info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
277
- info.append(f"DLs: {model.downloads}")
278
- info.append(f"likes: {model.likes}")
279
- info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
280
- md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
281
- return gr.update(value=md)
282
-
283
-
284
- def get_tupled_model_list(model_list):
285
- if not model_list: return []
286
- tupled_list = []
287
- for repo_id in model_list:
288
- api = HfApi()
289
- try:
290
- if not api.repo_exists(repo_id): continue
291
- model = api.model_info(repo_id=repo_id)
292
- except Exception as e:
293
- print(e)
294
- continue
295
- if model.private or model.gated: continue
296
- tags = model.tags
297
- info = []
298
- if not 'diffusers' in tags: continue
299
- for k, v in MODEL_TYPE_DICT.items():
300
- if k in tags: info.append(v)
301
- if model.card_data and model.card_data.tags:
302
- info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
303
- if "pony" in info:
304
- info.remove("pony")
305
- name = f"{repo_id} (Pony🐴, {', '.join(info)})"
306
- else:
307
- name = f"{repo_id} ({', '.join(info)})"
308
- tupled_list.append((name, repo_id))
309
- return tupled_list
310
-
311
-
312
- private_lora_dict = {}
313
- try:
314
- with open('lora_dict.json', encoding='utf-8') as f:
315
- d = json.load(f)
316
- for k, v in d.items():
317
- private_lora_dict[escape_lora_basename(k)] = v
318
- except Exception as e:
319
- print(e)
320
- loras_dict = {"None": ["", "", "", "", ""], "": ["", "", "", "", ""]} | private_lora_dict.copy()
321
- civitai_not_exists_list = []
322
- loras_url_to_path_dict = {} # {"URL to download": "local filepath", ...}
323
- civitai_lora_last_results = {} # {"URL to download": {search results}, ...}
324
- all_lora_list = []
325
-
326
-
327
- private_lora_model_list = []
328
- def get_private_lora_model_lists():
329
- global private_lora_model_list
330
- if len(private_lora_model_list) != 0: return private_lora_model_list
331
- models1 = []
332
- models2 = []
333
- for repo in HF_LORA_PRIVATE_REPOS1:
334
- models1.extend(get_private_model_list(repo, directory_loras))
335
- for repo in HF_LORA_PRIVATE_REPOS2:
336
- models2.extend(get_private_model_list(repo, directory_loras))
337
- models = list_uniq(models1 + sorted(models2))
338
- private_lora_model_list = models.copy()
339
- return models
340
-
341
-
342
- private_lora_model_list = get_private_lora_model_lists()
343
-
344
-
345
- def get_civitai_info(path):
346
- global civitai_not_exists_list
347
- import requests
348
- from urllib3.util import Retry
349
- from requests.adapters import HTTPAdapter
350
- if path in set(civitai_not_exists_list): return ["", "", "", "", ""]
351
- if not Path(path).exists(): return None
352
- user_agent = get_user_agent()
353
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
354
- base_url = 'https://civitai.com/api/v1/model-versions/by-hash/'
355
- params = {}
356
- session = requests.Session()
357
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
358
- session.mount("https://", HTTPAdapter(max_retries=retries))
359
- import hashlib
360
- with open(path, 'rb') as file:
361
- file_data = file.read()
362
- hash_sha256 = hashlib.sha256(file_data).hexdigest()
363
- url = base_url + hash_sha256
364
- try:
365
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
366
- except Exception as e:
367
- print(e)
368
- return ["", "", "", "", ""]
369
- if not r.ok: return None
370
- json = r.json()
371
- if not 'baseModel' in json:
372
- civitai_not_exists_list.append(path)
373
- return ["", "", "", "", ""]
374
- items = []
375
- items.append(" / ".join(json['trainedWords']))
376
- items.append(json['baseModel'])
377
- items.append(json['model']['name'])
378
- items.append(f"https://civitai.com/models/{json['modelId']}")
379
- items.append(json['images'][0]['url'])
380
- return items
381
-
382
-
383
- def get_lora_model_list():
384
- loras = list_uniq(get_private_lora_model_lists() + get_local_model_list(directory_loras) + DIFFUSERS_FORMAT_LORAS)
385
- loras.insert(0, "None")
386
- loras.insert(0, "")
387
- return loras
388
-
389
-
390
- def get_all_lora_list():
391
- global all_lora_list
392
- loras = get_lora_model_list()
393
- all_lora_list = loras.copy()
394
- return loras
395
-
396
-
397
- def get_all_lora_tupled_list():
398
- global loras_dict
399
- models = get_all_lora_list()
400
- if not models: return []
401
- tupled_list = []
402
- for model in models:
403
- #if not model: continue # to avoid GUI-related bug
404
- basename = Path(model).stem
405
- key = to_lora_key(model)
406
- items = None
407
- if key in loras_dict.keys():
408
- items = loras_dict.get(key, None)
409
- else:
410
- items = get_civitai_info(model)
411
- if items != None:
412
- loras_dict[key] = items
413
- name = basename
414
- value = model
415
- if items and items[2] != "":
416
- if items[1] == "Pony":
417
- name = f"{basename} (for {items[1]}🐴, {items[2]})"
418
- else:
419
- name = f"{basename} (for {items[1]}, {items[2]})"
420
- tupled_list.append((name, value))
421
- return tupled_list
422
-
423
-
424
- def update_lora_dict(path):
425
- global loras_dict
426
- key = escape_lora_basename(Path(path).stem)
427
- if key in loras_dict.keys(): return
428
- items = get_civitai_info(path)
429
- if items == None: return
430
- loras_dict[key] = items
431
-
432
-
433
- def download_lora(dl_urls: str):
434
- global loras_url_to_path_dict
435
- dl_path = ""
436
- before = get_local_model_list(directory_loras)
437
- urls = []
438
- for url in [url.strip() for url in dl_urls.split(',')]:
439
- local_path = f"{directory_loras}/{url.split('/')[-1]}"
440
- if not Path(local_path).exists():
441
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
442
- urls.append(url)
443
- after = get_local_model_list(directory_loras)
444
- new_files = list_sub(after, before)
445
- i = 0
446
- for file in new_files:
447
- path = Path(file)
448
- if path.exists():
449
- new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
450
- path.resolve().rename(new_path.resolve())
451
- loras_url_to_path_dict[urls[i]] = str(new_path)
452
- update_lora_dict(str(new_path))
453
- dl_path = str(new_path)
454
- i += 1
455
- return dl_path
456
-
457
-
458
- def copy_lora(path: str, new_path: str):
459
- import shutil
460
- if path == new_path: return new_path
461
- cpath = Path(path)
462
- npath = Path(new_path)
463
- if cpath.exists():
464
- try:
465
- shutil.copy(str(cpath.resolve()), str(npath.resolve()))
466
- except Exception as e:
467
- print(e)
468
- return None
469
- update_lora_dict(str(npath))
470
- return new_path
471
- else:
472
- return None
473
-
474
-
475
- def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: str, lora5: str):
476
- path = download_lora(dl_urls)
477
- if path:
478
- if not lora1 or lora1 == "None":
479
- lora1 = path
480
- elif not lora2 or lora2 == "None":
481
- lora2 = path
482
- elif not lora3 or lora3 == "None":
483
- lora3 = path
484
- elif not lora4 or lora4 == "None":
485
- lora4 = path
486
- elif not lora5 or lora5 == "None":
487
- lora5 = path
488
- choices = get_all_lora_tupled_list()
489
- return gr.update(value=lora1, choices=choices), gr.update(value=lora2, choices=choices), gr.update(value=lora3, choices=choices),\
490
- gr.update(value=lora4, choices=choices), gr.update(value=lora5, choices=choices)
491
-
492
-
493
- def get_valid_lora_name(query: str, model_name: str):
494
- path = "None"
495
- if not query or query == "None": return "None"
496
- if to_lora_key(query) in loras_dict.keys(): return query
497
- if query in loras_url_to_path_dict.keys():
498
- path = loras_url_to_path_dict[query]
499
- else:
500
- path = to_lora_path(query.strip().split('/')[-1])
501
- if Path(path).exists():
502
- return path
503
- elif "http" in query:
504
- dl_file = download_lora(query)
505
- if dl_file and Path(dl_file).exists(): return dl_file
506
- else:
507
- dl_file = find_similar_lora(query, model_name)
508
- if dl_file and Path(dl_file).exists(): return dl_file
509
- return "None"
510
-
511
-
512
- def get_valid_lora_path(query: str):
513
- path = None
514
- if not query or query == "None": return None
515
- if to_lora_key(query) in loras_dict.keys(): return query
516
- if Path(path).exists():
517
- return path
518
- else:
519
- return None
520
-
521
-
522
- def get_valid_lora_wt(prompt: str, lora_path: str, lora_wt: float):
523
- import re
524
- wt = lora_wt
525
- result = re.findall(f'<lora:{to_lora_key(lora_path)}:(.+?)>', prompt)
526
- if not result: return wt
527
- wt = safe_float(result[0][0])
528
- return wt
529
-
530
-
531
- def set_prompt_loras(prompt, prompt_syntax, model_name, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
532
- import re
533
- if not "Classic" in str(prompt_syntax): return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
534
- lora1 = get_valid_lora_name(lora1, model_name)
535
- lora2 = get_valid_lora_name(lora2, model_name)
536
- lora3 = get_valid_lora_name(lora3, model_name)
537
- lora4 = get_valid_lora_name(lora4, model_name)
538
- lora5 = get_valid_lora_name(lora5, model_name)
539
- if not "<lora" in prompt: return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
540
- lora1_wt = get_valid_lora_wt(prompt, lora1, lora1_wt)
541
- lora2_wt = get_valid_lora_wt(prompt, lora2, lora2_wt)
542
- lora3_wt = get_valid_lora_wt(prompt, lora3, lora3_wt)
543
- lora4_wt = get_valid_lora_wt(prompt, lora4, lora4_wt)
544
- lora5_wt = get_valid_lora_wt(prompt, lora5, lora5_wt)
545
- on1, label1, tag1, md1 = get_lora_info(lora1)
546
- on2, label2, tag2, md2 = get_lora_info(lora2)
547
- on3, label3, tag3, md3 = get_lora_info(lora3)
548
- on4, label4, tag4, md4 = get_lora_info(lora4)
549
- on5, label5, tag5, md5 = get_lora_info(lora5)
550
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
551
- prompts = prompt.split(",") if prompt else []
552
- for p in prompts:
553
- p = str(p).strip()
554
- if "<lora" in p:
555
- result = re.findall(r'<lora:(.+?):(.+?)>', p)
556
- if not result: continue
557
- key = result[0][0]
558
- wt = result[0][1]
559
- path = to_lora_path(key)
560
- if not key in loras_dict.keys() or not path:
561
- path = get_valid_lora_name(path)
562
- if not path or path == "None": continue
563
- if path in lora_paths:
564
- continue
565
- elif not on1:
566
- lora1 = path
567
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
568
- lora1_wt = safe_float(wt)
569
- on1 = True
570
- elif not on2:
571
- lora2 = path
572
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
573
- lora2_wt = safe_float(wt)
574
- on2 = True
575
- elif not on3:
576
- lora3 = path
577
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
578
- lora3_wt = safe_float(wt)
579
- on3 = True
580
- elif not on4:
581
- lora4 = path
582
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
583
- lora4_wt = safe_float(wt)
584
- on4, label4, tag4, md4 = get_lora_info(lora4)
585
- elif not on5:
586
- lora5 = path
587
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
588
- lora5_wt = safe_float(wt)
589
- on5 = True
590
- return lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt
591
-
592
-
593
- def get_lora_info(lora_path: str):
594
- is_valid = False
595
- tag = ""
596
- label = ""
597
- md = "None"
598
- if not lora_path or lora_path == "None":
599
- print("LoRA file not found.")
600
- return is_valid, label, tag, md
601
- path = Path(lora_path)
602
- new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
603
- if not to_lora_key(str(new_path)) in loras_dict.keys() and str(path) not in set(get_all_lora_list()):
604
- print("LoRA file is not registered.")
605
- return tag, label, tag, md
606
- if not new_path.exists():
607
- download_private_file_from_somewhere(str(path), True)
608
- basename = new_path.stem
609
- label = f'Name: {basename}'
610
- items = loras_dict.get(basename, None)
611
- if items == None:
612
- items = get_civitai_info(str(new_path))
613
- if items != None:
614
- loras_dict[basename] = items
615
- if items and items[2] != "":
616
- tag = items[0]
617
- label = f'Name: {basename}'
618
- if items[1] == "Pony":
619
- label = f'Name: {basename} (for Pony🐴)'
620
- if items[4]:
621
- md = f'<img src="{items[4]}" alt="thumbnail" width="150" height="240"><br>[LoRA Model URL]({items[3]})'
622
- elif items[3]:
623
- md = f'[LoRA Model URL]({items[3]})'
624
- is_valid = True
625
- return is_valid, label, tag, md
626
-
627
-
628
- def normalize_prompt_list(tags: list[str]):
629
- prompts = []
630
- for tag in tags:
631
- tag = str(tag).strip()
632
- if tag:
633
- prompts.append(tag)
634
- return prompts
635
-
636
-
637
- def apply_lora_prompt(prompt: str = "", lora_info: str = ""):
638
- if lora_info == "None": return gr.update(value=prompt)
639
- tags = prompt.split(",") if prompt else []
640
- prompts = normalize_prompt_list(tags)
641
-
642
- lora_tag = lora_info.replace("/",",")
643
- lora_tags = lora_tag.split(",") if str(lora_info) != "None" else []
644
- lora_prompts = normalize_prompt_list(lora_tags)
645
-
646
- empty = [""]
647
- prompt = ", ".join(list_uniq(prompts + lora_prompts) + empty)
648
- return gr.update(value=prompt)
649
-
650
-
651
- def update_loras(prompt, prompt_syntax, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
652
- import re
653
- on1, label1, tag1, md1 = get_lora_info(lora1)
654
- on2, label2, tag2, md2 = get_lora_info(lora2)
655
- on3, label3, tag3, md3 = get_lora_info(lora3)
656
- on4, label4, tag4, md4 = get_lora_info(lora4)
657
- on5, label5, tag5, md5 = get_lora_info(lora5)
658
- lora_paths = [lora1, lora2, lora3, lora4, lora5]
659
-
660
- output_prompt = prompt
661
- if "Classic" in str(prompt_syntax):
662
- prompts = prompt.split(",") if prompt else []
663
- output_prompts = []
664
- for p in prompts:
665
- p = str(p).strip()
666
- if "<lora" in p:
667
- result = re.findall(r'<lora:(.+?):(.+?)>', p)
668
- if not result: continue
669
- key = result[0][0]
670
- wt = result[0][1]
671
- path = to_lora_path(key)
672
- if not key in loras_dict.keys() or not path: continue
673
- if path in lora_paths:
674
- output_prompts.append(f"<lora:{to_lora_key(path)}:{safe_float(wt):.2f}>")
675
- elif p:
676
- output_prompts.append(p)
677
- lora_prompts = []
678
- if on1: lora_prompts.append(f"<lora:{to_lora_key(lora1)}:{lora1_wt:.2f}>")
679
- if on2: lora_prompts.append(f"<lora:{to_lora_key(lora2)}:{lora2_wt:.2f}>")
680
- if on3: lora_prompts.append(f"<lora:{to_lora_key(lora3)}:{lora3_wt:.2f}>")
681
- if on4: lora_prompts.append(f"<lora:{to_lora_key(lora4)}:{lora4_wt:.2f}>")
682
- if on5: lora_prompts.append(f"<lora:{to_lora_key(lora5)}:{lora5_wt:.2f}>")
683
- output_prompt = ", ".join(list_uniq(output_prompts + lora_prompts + [""]))
684
- choices = get_all_lora_tupled_list()
685
-
686
- return gr.update(value=output_prompt), gr.update(value=lora1, choices=choices), gr.update(value=lora1_wt),\
687
- gr.update(value=tag1, label=label1, visible=on1), gr.update(visible=on1), gr.update(value=md1, visible=on1),\
688
- gr.update(value=lora2, choices=choices), gr.update(value=lora2_wt),\
689
- gr.update(value=tag2, label=label2, visible=on2), gr.update(visible=on2), gr.update(value=md2, visible=on2),\
690
- gr.update(value=lora3, choices=choices), gr.update(value=lora3_wt),\
691
- gr.update(value=tag3, label=label3, visible=on3), gr.update(visible=on3), gr.update(value=md3, visible=on3),\
692
- gr.update(value=lora4, choices=choices), gr.update(value=lora4_wt),\
693
- gr.update(value=tag4, label=label4, visible=on4), gr.update(visible=on4), gr.update(value=md4, visible=on4),\
694
- gr.update(value=lora5, choices=choices), gr.update(value=lora5_wt),\
695
- gr.update(value=tag5, label=label5, visible=on5), gr.update(visible=on5), gr.update(value=md5, visible=on5)
696
-
697
-
698
- def get_my_lora(link_url):
699
- from pathlib import Path
700
- before = get_local_model_list(directory_loras)
701
- for url in [url.strip() for url in link_url.split(',')]:
702
- if not Path(f"{directory_loras}/{url.split('/')[-1]}").exists():
703
- download_things(directory_loras, url, HF_TOKEN, CIVITAI_API_KEY)
704
- after = get_local_model_list(directory_loras)
705
- new_files = list_sub(after, before)
706
- for file in new_files:
707
- path = Path(file)
708
- if path.exists():
709
- new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
710
- path.resolve().rename(new_path.resolve())
711
- update_lora_dict(str(new_path))
712
- new_lora_model_list = get_lora_model_list()
713
- new_lora_tupled_list = get_all_lora_tupled_list()
714
-
715
- return gr.update(
716
- choices=new_lora_tupled_list, value=new_lora_model_list[-1]
717
- ), gr.update(
718
- choices=new_lora_tupled_list
719
- ), gr.update(
720
- choices=new_lora_tupled_list
721
- ), gr.update(
722
- choices=new_lora_tupled_list
723
- ), gr.update(
724
- choices=new_lora_tupled_list
725
- )
726
-
727
-
728
- def upload_file_lora(files, progress=gr.Progress(track_tqdm=True)):
729
- progress(0, desc="Uploading...")
730
- file_paths = [file.name for file in files]
731
- progress(1, desc="Uploaded.")
732
- return gr.update(value=file_paths, visible=True), gr.update(visible=True)
733
-
734
-
735
- def move_file_lora(filepaths):
736
- import shutil
737
- for file in filepaths:
738
- path = Path(shutil.move(Path(file).resolve(), Path(f"./{directory_loras}").resolve()))
739
- newpath = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
740
- path.resolve().rename(newpath.resolve())
741
- update_lora_dict(str(newpath))
742
-
743
- new_lora_model_list = get_lora_model_list()
744
- new_lora_tupled_list = get_all_lora_tupled_list()
745
-
746
- return gr.update(
747
- choices=new_lora_tupled_list, value=new_lora_model_list[-1]
748
- ), gr.update(
749
- choices=new_lora_tupled_list
750
- ), gr.update(
751
- choices=new_lora_tupled_list
752
- ), gr.update(
753
- choices=new_lora_tupled_list
754
- ), gr.update(
755
- choices=new_lora_tupled_list
756
- )
757
-
758
-
759
- def get_civitai_info(path):
760
- global civitai_not_exists_list, loras_url_to_path_dict
761
- import requests
762
- from requests.adapters import HTTPAdapter
763
- from urllib3.util import Retry
764
- default = ["", "", "", "", ""]
765
- if path in set(civitai_not_exists_list): return default
766
- if not Path(path).exists(): return None
767
- user_agent = get_user_agent()
768
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
769
- base_url = 'https://civitai.com/api/v1/model-versions/by-hash/'
770
- params = {}
771
- session = requests.Session()
772
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
773
- session.mount("https://", HTTPAdapter(max_retries=retries))
774
- import hashlib
775
- with open(path, 'rb') as file:
776
- file_data = file.read()
777
- hash_sha256 = hashlib.sha256(file_data).hexdigest()
778
- url = base_url + hash_sha256
779
- try:
780
- r = session.get(url, params=params, headers=headers, stream=True, timeout=(3.0, 15))
781
- except Exception as e:
782
- print(e)
783
- return default
784
- else:
785
- if not r.ok: return None
786
- json = r.json()
787
- if 'baseModel' not in json:
788
- civitai_not_exists_list.append(path)
789
- return default
790
- items = []
791
- items.append(" / ".join(json['trainedWords'])) # The words (prompts) used to trigger the model
792
- items.append(json['baseModel']) # Base model (SDXL1.0, Pony, ...)
793
- items.append(json['model']['name']) # The name of the model version
794
- items.append(f"https://civitai.com/models/{json['modelId']}") # The repo url for the model
795
- items.append(json['images'][0]['url']) # The url for a sample image
796
- loras_url_to_path_dict[path] = json['downloadUrl'] # The download url to get the model file for this specific version
797
- return items
798
-
799
-
800
- def search_lora_on_civitai(query: str, allow_model: list[str] = ["Pony", "SDXL 1.0"], limit: int = 100,
801
- sort: str = "Highest Rated", period: str = "AllTime", tag: str = ""):
802
- import requests
803
- from requests.adapters import HTTPAdapter
804
- from urllib3.util import Retry
805
- user_agent = get_user_agent()
806
- headers = {'User-Agent': user_agent, 'content-type': 'application/json'}
807
- base_url = 'https://civitai.com/api/v1/models'
808
- params = {'types': ['LORA'], 'sort': sort, 'period': period, 'limit': limit, 'nsfw': 'true'}
809
- if query: params["query"] = query
810
- if tag: params["tag"] = tag
811
- session = requests.Session()
812
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
813
- session.mount("https://", HTTPAdapter(max_retries=retries))
814
- try:
815
- r = session.get(base_url, params=params, headers=headers, stream=True, timeout=(3.0, 30))
816
- except Exception as e:
817
- print(e)
818
- return None
819
- else:
820
- if not r.ok: return None
821
- json = r.json()
822
- if 'items' not in json: return None
823
- items = []
824
- for j in json['items']:
825
- for model in j['modelVersions']:
826
- item = {}
827
- if model['baseModel'] not in set(allow_model): continue
828
- item['name'] = j['name']
829
- item['creator'] = j['creator']['username']
830
- item['tags'] = j['tags']
831
- item['model_name'] = model['name']
832
- item['base_model'] = model['baseModel']
833
- item['dl_url'] = model['downloadUrl']
834
- item['md'] = f'<img src="{model["images"][0]["url"]}" alt="thumbnail" width="150" height="240"><br>[LoRA Model URL](https://civitai.com/models/{j["id"]})'
835
- items.append(item)
836
- return items
837
-
838
-
839
- def search_civitai_lora(query, base_model, sort="Highest Rated", period="AllTime", tag=""):
840
- global civitai_lora_last_results
841
- items = search_lora_on_civitai(query, base_model, 100, sort, period, tag)
842
- if not items: return gr.update(choices=[("", "")], value="", visible=False),\
843
- gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
844
- civitai_lora_last_results = {}
845
- choices = []
846
- for item in items:
847
- base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
848
- name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
849
- value = item['dl_url']
850
- choices.append((name, value))
851
- civitai_lora_last_results[value] = item
852
- if not choices: return gr.update(choices=[("", "")], value="", visible=False),\
853
- gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
854
- result = civitai_lora_last_results.get(choices[0][1], "None")
855
- md = result['md'] if result else ""
856
- return gr.update(choices=choices, value=choices[0][1], visible=True), gr.update(value=md, visible=True),\
857
- gr.update(visible=True), gr.update(visible=True)
858
-
859
-
860
- def select_civitai_lora(search_result):
861
- if not "http" in search_result: return gr.update(value=""), gr.update(value="None", visible=True)
862
- result = civitai_lora_last_results.get(search_result, "None")
863
- md = result['md'] if result else ""
864
- return gr.update(value=search_result), gr.update(value=md, visible=True)
865
-
866
-
867
- LORA_BASE_MODEL_DICT = {
868
- "diffusers:StableDiffusionPipeline": ["SD 1.5"],
869
- "diffusers:StableDiffusionXLPipeline": ["Pony", "SDXL 1.0"],
870
- "diffusers:FluxPipeline": ["Flux.1 D", "Flux.1 S"],
871
- }
872
-
873
-
874
- def get_lora_base_model(model_name: str):
875
- api = HfApi(token=HF_TOKEN)
876
- default = ["Pony", "SDXL 1.0"]
877
- try:
878
- model = api.model_info(repo_id=model_name, timeout=5.0)
879
- tags = model.tags
880
- for tag in tags:
881
- if tag in LORA_BASE_MODEL_DICT.keys(): return LORA_BASE_MODEL_DICT.get(tag, default)
882
- except Exception:
883
- return default
884
- return default
885
-
886
-
887
- def find_similar_lora(q: str, model_name: str):
888
- from rapidfuzz.process import extractOne
889
- from rapidfuzz.utils import default_process
890
- query = to_lora_key(q)
891
- print(f"Finding <lora:{query}:...>...")
892
- keys = list(private_lora_dict.keys())
893
- values = [x[2] for x in list(private_lora_dict.values())]
894
- s = default_process(query)
895
- e1 = extractOne(s, keys + values, processor=default_process, score_cutoff=80.0)
896
- key = ""
897
- if e1:
898
- e = e1[0]
899
- if e in set(keys): key = e
900
- elif e in set(values): key = keys[values.index(e)]
901
- if key:
902
- path = to_lora_path(key)
903
- new_path = to_lora_path(query)
904
- if not Path(path).exists():
905
- if not Path(new_path).exists(): download_private_file_from_somewhere(path, True)
906
- if Path(path).exists() and copy_lora(path, new_path): return new_path
907
- print(f"Finding <lora:{query}:...> on Civitai...")
908
- civitai_query = Path(query).stem if Path(query).is_file() else query
909
- civitai_query = civitai_query.replace("_", " ").replace("-", " ")
910
- base_model = get_lora_base_model(model_name)
911
- items = search_lora_on_civitai(civitai_query, base_model, 1)
912
- if items:
913
- item = items[0]
914
- path = download_lora(item['dl_url'])
915
- new_path = query if Path(query).is_file() else to_lora_path(query)
916
- if path and copy_lora(path, new_path): return new_path
917
- return None
918
-
919
-
920
- def change_interface_mode(mode: str):
921
- if mode == "Fast":
922
- return gr.update(open=False), gr.update(visible=True), gr.update(open=False), gr.update(open=False),\
923
- gr.update(visible=True), gr.update(open=False), gr.update(visible=True), gr.update(open=False),\
924
- gr.update(visible=True), gr.update(value="Fast")
925
- elif mode == "Simple": # t2i mode
926
- return gr.update(open=True), gr.update(visible=True), gr.update(open=False), gr.update(open=False),\
927
- gr.update(visible=True), gr.update(open=False), gr.update(visible=False), gr.update(open=True),\
928
- gr.update(visible=False), gr.update(value="Standard")
929
- elif mode == "LoRA": # t2i LoRA mode
930
- return gr.update(open=True), gr.update(visible=True), gr.update(open=True), gr.update(open=False),\
931
- gr.update(visible=True), gr.update(open=True), gr.update(visible=True), gr.update(open=False),\
932
- gr.update(visible=False), gr.update(value="Standard")
933
- else: # Standard
934
- return gr.update(open=False), gr.update(visible=True), gr.update(open=False), gr.update(open=False),\
935
- gr.update(visible=True), gr.update(open=False), gr.update(visible=True), gr.update(open=False),\
936
- gr.update(visible=True), gr.update(value="Standard")
937
-
938
-
939
- quality_prompt_list = [
940
- {
941
- "name": "None",
942
- "prompt": "",
943
- "negative_prompt": "lowres",
944
- },
945
- {
946
- "name": "Animagine Common",
947
- "prompt": "anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
948
- "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
949
- },
950
- {
951
- "name": "Pony Anime Common",
952
- "prompt": "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres",
953
- "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
954
- },
955
- {
956
- "name": "Pony Common",
957
- "prompt": "source_anime, score_9, score_8_up, score_7_up",
958
- "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
959
- },
960
- {
961
- "name": "Animagine Standard v3.0",
962
- "prompt": "masterpiece, best quality",
963
- "negative_prompt": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
964
- },
965
- {
966
- "name": "Animagine Standard v3.1",
967
- "prompt": "masterpiece, best quality, very aesthetic, absurdres",
968
- "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
969
- },
970
- {
971
- "name": "Animagine Light v3.1",
972
- "prompt": "(masterpiece), best quality, very aesthetic, perfect face",
973
- "negative_prompt": "(low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
974
- },
975
- {
976
- "name": "Animagine Heavy v3.1",
977
- "prompt": "(masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
978
- "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
979
- },
980
- ]
981
-
982
-
983
- style_list = [
984
- {
985
- "name": "None",
986
- "prompt": "",
987
- "negative_prompt": "",
988
- },
989
- {
990
- "name": "Cinematic",
991
- "prompt": "cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
992
- "negative_prompt": "cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
993
- },
994
- {
995
- "name": "Photographic",
996
- "prompt": "cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
997
- "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
998
- },
999
- {
1000
- "name": "Anime",
1001
- "prompt": "anime artwork, anime style, vibrant, studio anime, highly detailed",
1002
- "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
1003
- },
1004
- {
1005
- "name": "Manga",
1006
- "prompt": "manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
1007
- "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
1008
- },
1009
- {
1010
- "name": "Digital Art",
1011
- "prompt": "concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
1012
- "negative_prompt": "photo, photorealistic, realism, ugly",
1013
- },
1014
- {
1015
- "name": "Pixel art",
1016
- "prompt": "pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
1017
- "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
1018
- },
1019
- {
1020
- "name": "Fantasy art",
1021
- "prompt": "ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
1022
- "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
1023
- },
1024
- {
1025
- "name": "Neonpunk",
1026
- "prompt": "neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
1027
- "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
1028
- },
1029
- {
1030
- "name": "3D Model",
1031
- "prompt": "professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
1032
- "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
1033
- },
1034
- ]
1035
-
1036
-
1037
- optimization_list = {
1038
- "None": [28, 7., 'Euler a', False, 'None', 1.],
1039
- "Default": [28, 7., 'Euler a', False, 'None', 1.],
1040
- "SPO": [28, 7., 'Euler a', True, 'loras/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors', 1.],
1041
- "DPO": [28, 7., 'Euler a', True, 'loras/sdxl-DPO-LoRA.safetensors', 1.],
1042
- "DPO Turbo": [8, 2.5, 'LCM', True, 'loras/sd_xl_dpo_turbo_lora_v1-128dim.safetensors', 1.],
1043
- "SDXL Turbo": [8, 2.5, 'LCM', True, 'loras/sd_xl_turbo_lora_v1.safetensors', 1.],
1044
- "Hyper-SDXL 12step": [12, 5., 'TCD', True, 'loras/Hyper-SDXL-12steps-CFG-lora.safetensors', 1.],
1045
- "Hyper-SDXL 8step": [8, 5., 'TCD', True, 'loras/Hyper-SDXL-8steps-CFG-lora.safetensors', 1.],
1046
- "Hyper-SDXL 4step": [4, 0, 'TCD', True, 'loras/Hyper-SDXL-4steps-lora.safetensors', 1.],
1047
- "Hyper-SDXL 2step": [2, 0, 'TCD', True, 'loras/Hyper-SDXL-2steps-lora.safetensors', 1.],
1048
- "Hyper-SDXL 1step": [1, 0, 'TCD', True, 'loras/Hyper-SDXL-1steps-lora.safetensors', 1.],
1049
- "PCM 16step": [16, 4., 'Euler a trailing', True, 'loras/pcm_sdxl_normalcfg_16step_converted.safetensors', 1.],
1050
- "PCM 8step": [8, 4., 'Euler a trailing', True, 'loras/pcm_sdxl_normalcfg_8step_converted.safetensors', 1.],
1051
- "PCM 4step": [4, 2., 'Euler a trailing', True, 'loras/pcm_sdxl_smallcfg_4step_converted.safetensors', 1.],
1052
- "PCM 2step": [2, 1., 'Euler a trailing', True, 'loras/pcm_sdxl_smallcfg_2step_converted.safetensors', 1.],
1053
- }
1054
-
1055
-
1056
- def set_optimization(opt, steps_gui, cfg_gui, sampler_gui, clip_skip_gui, lora_gui, lora_scale_gui):
1057
- if not opt in list(optimization_list.keys()): opt = "None"
1058
- def_steps_gui = 28
1059
- def_cfg_gui = 7.
1060
- steps = optimization_list.get(opt, "None")[0]
1061
- cfg = optimization_list.get(opt, "None")[1]
1062
- sampler = optimization_list.get(opt, "None")[2]
1063
- clip_skip = optimization_list.get(opt, "None")[3]
1064
- lora = optimization_list.get(opt, "None")[4]
1065
- lora_scale = optimization_list.get(opt, "None")[5]
1066
- if opt == "None":
1067
- steps = max(steps_gui, def_steps_gui)
1068
- cfg = max(cfg_gui, def_cfg_gui)
1069
- clip_skip = clip_skip_gui
1070
- elif opt == "SPO" or opt == "DPO":
1071
- steps = max(steps_gui, def_steps_gui)
1072
- cfg = max(cfg_gui, def_cfg_gui)
1073
-
1074
- return gr.update(value=steps), gr.update(value=cfg), gr.update(value=sampler),\
1075
- gr.update(value=clip_skip), gr.update(value=lora), gr.update(value=lora_scale),
1076
-
1077
-
1078
- # [sampler_gui, steps_gui, cfg_gui, clip_skip_gui, img_width_gui, img_height_gui, optimization_gui]
1079
- preset_sampler_setting = {
1080
- "None": ["Euler a", 28, 7., True, 1024, 1024, "None"],
1081
- "Anime 3:4 Fast": ["LCM", 8, 2.5, True, 896, 1152, "DPO Turbo"],
1082
- "Anime 3:4 Standard": ["Euler a", 28, 7., True, 896, 1152, "None"],
1083
- "Anime 3:4 Heavy": ["Euler a", 40, 7., True, 896, 1152, "None"],
1084
- "Anime 1:1 Fast": ["LCM", 8, 2.5, True, 1024, 1024, "DPO Turbo"],
1085
- "Anime 1:1 Standard": ["Euler a", 28, 7., True, 1024, 1024, "None"],
1086
- "Anime 1:1 Heavy": ["Euler a", 40, 7., True, 1024, 1024, "None"],
1087
- "Photo 3:4 Fast": ["LCM", 8, 2.5, False, 896, 1152, "DPO Turbo"],
1088
- "Photo 3:4 Standard": ["DPM++ 2M Karras", 28, 7., False, 896, 1152, "None"],
1089
- "Photo 3:4 Heavy": ["DPM++ 2M Karras", 40, 7., False, 896, 1152, "None"],
1090
- "Photo 1:1 Fast": ["LCM", 8, 2.5, False, 1024, 1024, "DPO Turbo"],
1091
- "Photo 1:1 Standard": ["DPM++ 2M Karras", 28, 7., False, 1024, 1024, "None"],
1092
- "Photo 1:1 Heavy": ["DPM++ 2M Karras", 40, 7., False, 1024, 1024, "None"],
1093
- }
1094
-
1095
-
1096
- def set_sampler_settings(sampler_setting):
1097
- if not sampler_setting in list(preset_sampler_setting.keys()) or sampler_setting == "None":
1098
- return gr.update(value="Euler a"), gr.update(value=28), gr.update(value=7.), gr.update(value=True),\
1099
- gr.update(value=1024), gr.update(value=1024), gr.update(value="None")
1100
- v = preset_sampler_setting.get(sampler_setting, ["Euler a", 28, 7., True, 1024, 1024])
1101
- # sampler, steps, cfg, clip_skip, width, height, optimization
1102
- return gr.update(value=v[0]), gr.update(value=v[1]), gr.update(value=v[2]), gr.update(value=v[3]),\
1103
- gr.update(value=v[4]), gr.update(value=v[5]), gr.update(value=v[6])
1104
-
1105
-
1106
- preset_styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
1107
- preset_quality = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in quality_prompt_list}
1108
-
1109
-
1110
- def process_style_prompt(prompt: str, neg_prompt: str, styles_key: str = "None", quality_key: str = "None", type: str = "Auto"):
1111
- def to_list(s):
1112
- return [x.strip() for x in s.split(",") if not s == ""]
1113
-
1114
- def list_sub(a, b):
1115
- return [e for e in a if e not in b]
1116
-
1117
- def list_uniq(l):
1118
- return sorted(set(l), key=l.index)
1119
-
1120
- animagine_ps = to_list("anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
1121
- animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
1122
- pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
1123
- pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
1124
- prompts = to_list(prompt)
1125
- neg_prompts = to_list(neg_prompt)
1126
-
1127
- all_styles_ps = []
1128
- all_styles_nps = []
1129
- for d in style_list:
1130
- all_styles_ps.extend(to_list(str(d.get("prompt", ""))))
1131
- all_styles_nps.extend(to_list(str(d.get("negative_prompt", ""))))
1132
-
1133
- all_quality_ps = []
1134
- all_quality_nps = []
1135
- for d in quality_prompt_list:
1136
- all_quality_ps.extend(to_list(str(d.get("prompt", ""))))
1137
- all_quality_nps.extend(to_list(str(d.get("negative_prompt", ""))))
1138
-
1139
- quality_ps = to_list(preset_quality[quality_key][0])
1140
- quality_nps = to_list(preset_quality[quality_key][1])
1141
- styles_ps = to_list(preset_styles[styles_key][0])
1142
- styles_nps = to_list(preset_styles[styles_key][1])
1143
-
1144
- prompts = list_sub(prompts, animagine_ps + pony_ps + all_styles_ps + all_quality_ps)
1145
- neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + all_styles_nps + all_quality_nps)
1146
-
1147
- last_empty_p = [""] if not prompts and type != "None" and type != "Auto" and styles_key != "None" and quality_key != "None" else []
1148
- last_empty_np = [""] if not neg_prompts and type != "None" and type != "Auto" and styles_key != "None" and quality_key != "None" else []
1149
-
1150
- if type == "Animagine":
1151
- prompts = prompts + animagine_ps
1152
- neg_prompts = neg_prompts + animagine_nps
1153
- elif type == "Pony":
1154
- prompts = prompts + pony_ps
1155
- neg_prompts = neg_prompts + pony_nps
1156
-
1157
- prompts = prompts + styles_ps + quality_ps
1158
- neg_prompts = neg_prompts + styles_nps + quality_nps
1159
-
1160
- prompt = ", ".join(list_uniq(prompts) + last_empty_p)
1161
- neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)
1162
-
1163
- return gr.update(value=prompt), gr.update(value=neg_prompt), gr.update(value=type)
1164
-
1165
-
1166
- def set_quick_presets(genre:str = "None", type:str = "Auto", speed:str = "None", aspect:str = "None"):
1167
- quality = "None"
1168
- style = "None"
1169
- sampler = "None"
1170
- opt = "None"
1171
-
1172
- if genre == "Anime":
1173
- if type != "None" and type != "Auto": style = "Anime"
1174
- if aspect == "1:1":
1175
- if speed == "Heavy":
1176
- sampler = "Anime 1:1 Heavy"
1177
- elif speed == "Fast":
1178
- sampler = "Anime 1:1 Fast"
1179
- else:
1180
- sampler = "Anime 1:1 Standard"
1181
- elif aspect == "3:4":
1182
- if speed == "Heavy":
1183
- sampler = "Anime 3:4 Heavy"
1184
- elif speed == "Fast":
1185
- sampler = "Anime 3:4 Fast"
1186
- else:
1187
- sampler = "Anime 3:4 Standard"
1188
- if type == "Pony":
1189
- quality = "Pony Anime Common"
1190
- elif type == "Animagine":
1191
- quality = "Animagine Common"
1192
- else:
1193
- quality = "None"
1194
- elif genre == "Photo":
1195
- if type != "None" and type != "Auto": style = "Photographic"
1196
- if aspect == "1:1":
1197
- if speed == "Heavy":
1198
- sampler = "Photo 1:1 Heavy"
1199
- elif speed == "Fast":
1200
- sampler = "Photo 1:1 Fast"
1201
- else:
1202
- sampler = "Photo 1:1 Standard"
1203
- elif aspect == "3:4":
1204
- if speed == "Heavy":
1205
- sampler = "Photo 3:4 Heavy"
1206
- elif speed == "Fast":
1207
- sampler = "Photo 3:4 Fast"
1208
- else:
1209
- sampler = "Photo 3:4 Standard"
1210
- if type == "Pony":
1211
- quality = "Pony Common"
1212
- else:
1213
- quality = "None"
1214
-
1215
- if speed == "Fast":
1216
- opt = "DPO Turbo"
1217
- if genre == "Anime" and type != "Pony" and type != "Auto": quality = "Animagine Light v3.1"
1218
-
1219
- return gr.update(value=quality), gr.update(value=style), gr.update(value=sampler), gr.update(value=opt), gr.update(value=type)
1220
-
1221
-
1222
- textual_inversion_dict = {}
1223
- try:
1224
- with open('textual_inversion_dict.json', encoding='utf-8') as f:
1225
- textual_inversion_dict = json.load(f)
1226
- except Exception:
1227
- pass
1228
- textual_inversion_file_token_list = []
1229
-
1230
-
1231
- def get_tupled_embed_list(embed_list):
1232
- global textual_inversion_file_list
1233
- tupled_list = []
1234
- for file in embed_list:
1235
- token = textual_inversion_dict.get(Path(file).name, [Path(file).stem.replace(",",""), False])[0]
1236
- tupled_list.append((token, file))
1237
- textual_inversion_file_token_list.append(token)
1238
- return tupled_list
1239
-
1240
-
1241
- def set_textual_inversion_prompt(textual_inversion_gui, prompt_gui, neg_prompt_gui, prompt_syntax_gui):
1242
- ti_tags = list(textual_inversion_dict.values()) + textual_inversion_file_token_list
1243
- tags = prompt_gui.split(",") if prompt_gui else []
1244
- prompts = []
1245
- for tag in tags:
1246
- tag = str(tag).strip()
1247
- if tag and not tag in ti_tags:
1248
- prompts.append(tag)
1249
- ntags = neg_prompt_gui.split(",") if neg_prompt_gui else []
1250
- neg_prompts = []
1251
- for tag in ntags:
1252
- tag = str(tag).strip()
1253
- if tag and not tag in ti_tags:
1254
- neg_prompts.append(tag)
1255
- ti_prompts = []
1256
- ti_neg_prompts = []
1257
- for ti in textual_inversion_gui:
1258
- tokens = textual_inversion_dict.get(Path(ti).name, [Path(ti).stem.replace(",",""), False])
1259
- is_positive = tokens[1] == True or "positive" in Path(ti).parent.name
1260
- if is_positive: # positive prompt
1261
- ti_prompts.append(tokens[0])
1262
- else: # negative prompt (default)
1263
- ti_neg_prompts.append(tokens[0])
1264
- empty = [""]
1265
- prompt = ", ".join(prompts + ti_prompts + empty)
1266
- neg_prompt = ", ".join(neg_prompts + ti_neg_prompts + empty)
1267
- return gr.update(value=prompt), gr.update(value=neg_prompt),
1268
-
1269
-
1270
- def get_model_pipeline(repo_id: str):
1271
- from huggingface_hub import HfApi
1272
- api = HfApi(token=HF_TOKEN)
1273
- default = "StableDiffusionPipeline"
1274
- try:
1275
- if not is_repo_name(repo_id): return default
1276
- model = api.model_info(repo_id=repo_id, timeout=5.0)
1277
- except Exception:
1278
- return default
1279
- if model.private or model.gated: return default
1280
- tags = model.tags
1281
- if not 'diffusers' in tags: return default
1282
- if 'diffusers:FluxPipeline' in tags:
1283
- return "FluxPipeline"
1284
- if 'diffusers:StableDiffusionXLPipeline' in tags:
1285
- return "StableDiffusionXLPipeline"
1286
- elif 'diffusers:StableDiffusionPipeline' in tags:
1287
- return "StableDiffusionPipeline"
1288
- else:
1289
- return default
1290
-