John6666 commited on
Commit
56299fc
1 Parent(s): 9e18add

Upload 13 files

Browse files
Files changed (3) hide show
  1. README.md +1 -0
  2. app.py +86 -97
  3. multit2i.py +90 -88
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 4.39.0
8
  app_file: app.py
9
  pinned: false
10
  short_description: Text-to-Image
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  short_description: Text-to-Image
11
+ license: mit
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,71 +1,37 @@
1
  import gradio as gr
2
  from model import models
3
  from multit2i import (
4
- load_models,
5
- infer_multi,
6
- infer_multi_random,
7
- save_gallery_images,
8
- change_model,
9
- get_model_info_md,
10
- loaded_models,
11
- get_positive_prefix,
12
- get_positive_suffix,
13
- get_negative_prefix,
14
- get_negative_suffix,
15
- get_recom_prompt_type,
16
- set_recom_prompt_preset,
17
- get_tag_type,
18
  )
19
  from tagger.tagger import (
20
- predict_tags_wd,
21
- remove_specific_prompt,
22
- convert_danbooru_to_e621_prompt,
23
- insert_recom_prompt,
24
- compose_prompt_to_copy,
25
  )
26
  from tagger.fl2sd3longcap import predict_tags_fl2_sd3
27
- from tagger.v2 import (
28
- V2_ALL_MODELS,
29
- v2_random_prompt,
30
- )
31
  from tagger.utils import (
32
- V2_ASPECT_RATIO_OPTIONS,
33
- V2_RATING_OPTIONS,
34
- V2_LENGTH_OPTIONS,
35
- V2_IDENTITY_OPTIONS,
36
  )
37
 
38
 
39
- load_models(models, 5)
40
- #load_models(models, 20) # Fetching 20 models at the same time. default: 5
41
-
42
 
43
  css = """
44
- #model_info { text-align: center; }
 
 
45
  """
46
 
47
- with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
48
  with gr.Column():
49
- with gr.Accordion("Advanced settings", open=False):
50
- with gr.Accordion("Recommended Prompt", open=False):
51
- recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
52
- with gr.Row():
53
- positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
54
- positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
55
- negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[], visible=False)
56
- negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"], visible=False)
57
- with gr.Accordion("Prompt Transformer", open=False):
58
- v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
59
- v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
60
- v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
61
- v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
62
- v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
63
- v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
64
- v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
65
- v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
66
- with gr.Accordion("Model", open=True):
67
- model_name = gr.Dropdown(label="Select Model", show_label=False, choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
68
- model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_id="model_info")
69
  with gr.Group():
70
  with gr.Accordion("Prompt from Image File", open=False):
71
  tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
@@ -82,59 +48,82 @@ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as demo:
82
  v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
83
  random_prompt = gr.Button(value="Extend Prompt 🎲", size="sm", scale=1)
84
  clear_prompt = gr.Button(value="Clear Prompt 🗑️", size="sm", scale=1)
85
- prompt = gr.Text(label="Prompt", lines=1, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
86
  neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  with gr.Row():
88
  run_button = gr.Button("Generate Image", scale=6)
89
  random_button = gr.Button("Random Model 🎲", scale=3)
90
- image_num = gr.Number(label="Count", minimum=1, maximum=16, value=1, step=1, interactive=True, scale=1)
91
- results = gr.Gallery(label="Gallery", interactive=False, show_download_button=True, show_share_button=False,
92
- container=True, format="png", object_fit="contain")
93
- image_files = gr.Files(label="Download", interactive=False)
94
- clear_results = gr.Button("Clear Gallery / Download")
95
- examples = gr.Examples(
96
- examples = [
97
- ["souryuu asuka langley, 1girl, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors"],
98
- ["sailor moon, magical girl transformation, sparkles and ribbons, soft pastel colors, crescent moon motif, starry night sky background, shoujo manga style"],
99
- ["kafuu chino, 1girl, solo"],
100
- ["1girl"],
101
- ["beautiful sunset"],
102
- ],
103
- inputs=[prompt],
104
- )
105
- gr.Markdown(
106
- f"""This demo was created in reference to the following demos.
107
- - [Nymbo/Flood](https://huggingface.co/spaces/Nymbo/Flood).
108
- - [Yntec/ToyWorldXL](https://huggingface.co/spaces/Yntec/ToyWorldXL).
109
- """
110
- )
111
- gr.DuplicateButton(value="Duplicate Space")
 
 
 
 
 
 
 
 
 
112
 
113
- model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)
114
- gr.on(
115
- triggers=[run_button.click, prompt.submit],
116
- fn=infer_multi,
117
- inputs=[prompt, neg_prompt, results, image_num, model_name,
118
- positive_prefix, positive_suffix, negative_prefix, negative_suffix],
119
- outputs=[results],
120
- queue=True,
121
- show_progress="full",
122
- show_api=True,
123
- ).success(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
124
- gr.on(
125
- triggers=[random_button.click],
126
- fn=infer_multi_random,
127
- inputs=[prompt, neg_prompt, results, image_num,
128
- positive_prefix, positive_suffix, negative_prefix, negative_suffix],
129
- outputs=[results],
130
- queue=True,
131
- show_progress="full",
132
- show_api=True,
133
- ).success(save_gallery_images, [results], [results, image_files], queue=False, show_api=False)
134
- clear_prompt.click(lambda: (None, None, None), None, [prompt, v2_series, v2_character], queue=False, show_api=False)
135
  clear_results.click(lambda: (None, None), None, [results, image_files], queue=False, show_api=False)
136
  recom_prompt_preset.change(set_recom_prompt_preset, [recom_prompt_preset],
137
  [positive_prefix, positive_suffix, negative_prefix, negative_suffix], queue=False, show_api=False)
 
138
  random_prompt.click(
139
  v2_random_prompt, [prompt, v2_series, v2_character, v2_rating, v2_aspect_ratio, v2_length,
140
  v2_identity, v2_ban_tags, v2_model], [prompt, v2_series, v2_character], show_api=False,
 
1
  import gradio as gr
2
  from model import models
3
  from multit2i import (
4
+ load_models, infer_fn, infer_rand_fn, save_gallery,
5
+ change_model, warm_model, get_model_info_md, loaded_models,
6
+ get_positive_prefix, get_positive_suffix, get_negative_prefix, get_negative_suffix,
7
+ get_recom_prompt_type, set_recom_prompt_preset, get_tag_type,
 
 
 
 
 
 
 
 
 
 
8
  )
9
  from tagger.tagger import (
10
+ predict_tags_wd, remove_specific_prompt, convert_danbooru_to_e621_prompt,
11
+ insert_recom_prompt, compose_prompt_to_copy,
 
 
 
12
  )
13
  from tagger.fl2sd3longcap import predict_tags_fl2_sd3
14
+ from tagger.v2 import V2_ALL_MODELS, v2_random_prompt
 
 
 
15
  from tagger.utils import (
16
+ V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS,
17
+ V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS,
 
 
18
  )
19
 
20
 
21
+ max_images = 8
22
+ load_models(models)
 
23
 
24
  css = """
25
+ .model_info { text-align: center; }
26
+ .output { width=112px; height=112px; !important; }
27
+ .gallery { width=100%; min_height=768px; !important; }
28
  """
29
 
30
+ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
31
  with gr.Column():
32
+ with gr.Group():
33
+ model_name = gr.Dropdown(label="Select Model", choices=list(loaded_models.keys()), value=list(loaded_models.keys())[0], allow_custom_value=True)
34
+ model_info = gr.Markdown(value=get_model_info_md(list(loaded_models.keys())[0]), elem_classes="model_info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  with gr.Group():
36
  with gr.Accordion("Prompt from Image File", open=False):
37
  tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
 
48
  v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2)
49
  random_prompt = gr.Button(value="Extend Prompt 🎲", size="sm", scale=1)
50
  clear_prompt = gr.Button(value="Clear Prompt 🗑️", size="sm", scale=1)
51
+ prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
52
  neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
53
+ with gr.Accordion("Recommended Prompt", open=False):
54
+ recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
55
+ with gr.Row():
56
+ positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
57
+ positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
58
+ negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[], visible=False)
59
+ negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"], visible=False)
60
+ with gr.Accordion("Prompt Transformer", open=False):
61
+ v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
62
+ v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
63
+ v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")
64
+ v2_identity = gr.Radio(label="Keep identity", info="How strictly to keep the identity of the character or subject. If you specify the detail of subject in the prompt, you should choose `strict`. Otherwise, choose `none` or `lax`. `none` is very creative but sometimes ignores the input prompt.", choices=list(V2_IDENTITY_OPTIONS), value="lax")
65
+ v2_ban_tags = gr.Textbox(label="Ban tags", info="Tags to ban from the output.", placeholder="alternate costumen, ...", value="censored")
66
+ v2_tag_type = gr.Radio(label="Tag Type", info="danbooru for common, e621 for Pony.", choices=["danbooru", "e621"], value="danbooru", visible=False)
67
+ v2_model = gr.Dropdown(label="Model", choices=list(V2_ALL_MODELS.keys()), value=list(V2_ALL_MODELS.keys())[0])
68
+ v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
69
+ image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=1)
70
  with gr.Row():
71
  run_button = gr.Button("Generate Image", scale=6)
72
  random_button = gr.Button("Random Model 🎲", scale=3)
73
+ stop_button = gr.Button('Stop', interactive=False, scale=1)
74
+ with gr.Column():
75
+ with gr.Group():
76
+ with gr.Row():
77
+ output = [gr.Image(label='', elem_classes="output", type="filepath", format=".png",
78
+ show_download_button=True, show_share_button=False, show_label=False,
79
+ interactive=False, min_width=80, visible=True) for _ in range(max_images)]
80
+ with gr.Group():
81
+ results = gr.Gallery(label="Gallery", elem_classes="gallery", interactive=False, show_download_button=True, show_share_button=False,
82
+ container=True, format="png", object_fit="cover", columns=2, rows=2)
83
+ image_files = gr.Files(label="Download", interactive=False)
84
+ clear_results = gr.Button("Clear Gallery / Download 🗑️")
85
+ with gr.Column():
86
+ examples = gr.Examples(
87
+ examples = [
88
+ ["souryuu asuka langley, 1girl, neon genesis evangelion, plugsuit, pilot suit, red bodysuit, sitting, crossing legs, black eye patch, cat hat, throne, symmetrical, looking down, from bottom, looking at viewer, outdoors"],
89
+ ["sailor moon, magical girl transformation, sparkles and ribbons, soft pastel colors, crescent moon motif, starry night sky background, shoujo manga style"],
90
+ ["kafuu chino, 1girl, solo"],
91
+ ["1girl"],
92
+ ["beautiful sunset"],
93
+ ],
94
+ inputs=[prompt],
95
+ )
96
+ gr.Markdown(
97
+ f"""This demo was created in reference to the following demos.<br>
98
+ [Nymbo/Flood](https://huggingface.co/spaces/Nymbo/Flood),
99
+ [Yntec/ToyWorldXL](https://huggingface.co/spaces/Yntec/ToyWorldXL),
100
+ [Yntec/Diffusion80XX](https://huggingface.co/spaces/Yntec/Diffusion80XX).
101
+ """
102
+ )
103
+ gr.DuplicateButton(value="Duplicate Space")
104
 
105
+ gr.on(triggers=[run_button.click, prompt.submit, random_button.click], fn=lambda: gr.update(interactive=True), inputs=None, outputs=stop_button, show_api=False)
106
+ model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)\
107
+ .success(warm_model, [model_name], None, queue=True, show_api=False)
108
+ for i, o in enumerate(output):
109
+ img_i = gr.Number(i, visible=False)
110
+ image_num.change(lambda i, n: gr.update(visible = (i < n)), [img_i, image_num], o, show_api=False)
111
+ gen_event = gr.on(triggers=[run_button.click, prompt.submit],
112
+ fn=lambda i, n, m, t1, t2, l1, l2, l3, l4: infer_fn(m, t1, t2, l1, l2, l3, l4) if (i < n) else None,
113
+ inputs=[img_i, image_num, model_name, prompt, neg_prompt, positive_prefix, positive_suffix, negative_prefix, negative_suffix],
114
+ outputs=[o], queue=True, show_api=False)
115
+ gen_event2 = gr.on(triggers=[random_button.click],
116
+ fn=lambda i, n, m, t1, t2, l1, l2, l3, l4: infer_rand_fn(m, t1, t2, l1, l2, l3, l4) if (i < n) else None,
117
+ inputs=[img_i, image_num, model_name, prompt, neg_prompt, positive_prefix, positive_suffix, negative_prefix, negative_suffix],
118
+ outputs=[o], queue=True, show_api=False)
119
+ o.change(save_gallery, [o, results], [results, image_files], show_api=False)
120
+ stop_button.click(lambda: gr.update(interactive=False), None, stop_button, cancels=[gen_event, gen_event2], show_api=False)
121
+
122
+ clear_prompt.click(lambda: None, None, [prompt], queue=False, show_api=False)
 
 
 
 
123
  clear_results.click(lambda: (None, None), None, [results, image_files], queue=False, show_api=False)
124
  recom_prompt_preset.change(set_recom_prompt_preset, [recom_prompt_preset],
125
  [positive_prefix, positive_suffix, negative_prefix, negative_suffix], queue=False, show_api=False)
126
+
127
  random_prompt.click(
128
  v2_random_prompt, [prompt, v2_series, v2_character, v2_rating, v2_aspect_ratio, v2_length,
129
  v2_identity, v2_ban_tags, v2_model], [prompt, v2_series, v2_character], show_api=False,
multit2i.py CHANGED
@@ -80,33 +80,36 @@ def get_t2i_model_info_dict(repo_id: str):
80
  return info
81
 
82
 
83
- def save_gallery_images(images, progress=gr.Progress(track_tqdm=True)):
 
84
  from datetime import datetime, timezone, timedelta
85
- progress(0, desc="Updating gallery...")
86
  dt_now = datetime.now(timezone(timedelta(hours=9)))
87
- basename = dt_now.strftime('%Y%m%d_%H%M%S_')
88
- i = 1
89
- if not images: return images
90
- output_images = []
91
- output_paths = []
92
- for image in images:
93
- filename = f'{image[1]}_{basename}{str(i)}.png'
94
- i += 1
95
- oldpath = Path(image[0])
96
- newpath = oldpath
97
- try:
98
- if oldpath.stem == "image" and oldpath.exists():
99
- newpath = oldpath.resolve().rename(Path(filename).resolve())
100
- except Exception as e:
101
- print(e)
102
- pass
103
- finally:
104
- output_paths.append(str(newpath))
105
- output_images.append((str(newpath), str(filename)))
106
- progress(1, desc="Gallery updated.")
107
- return gr.update(value=output_images), gr.update(value=output_paths)
108
 
109
 
 
 
110
  def load_from_model(model_name: str, hf_token: str = None):
111
  import httpx
112
  import huggingface_hub
@@ -122,7 +125,7 @@ def load_from_model(model_name: str, hf_token: str = None):
122
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
123
  )
124
  headers["X-Wait-For-Model"] = "true"
125
- client = huggingface_hub.InferenceClient(model=model_name, headers=headers, token=hf_token)
126
  inputs = gr.components.Textbox(label="Input")
127
  outputs = gr.components.Image(label="Output")
128
  fn = client.text_to_image
@@ -161,27 +164,9 @@ def load_model(model_name: str):
161
  return loaded_models[model_name]
162
 
163
 
164
- async def async_load_models(models: list, limit: int=5):
165
- sem = asyncio.Semaphore(limit)
166
- async def async_load_model(model: str):
167
- async with sem:
168
- try:
169
- return await asyncio.to_thread(load_model, model)
170
- except Exception as e:
171
- print(e)
172
- tasks = [asyncio.create_task(async_load_model(model)) for model in models]
173
- return await asyncio.gather(*tasks, return_exceptions=True)
174
-
175
-
176
- def load_models(models: list, limit: int=5):
177
- loop = asyncio.new_event_loop()
178
- try:
179
- loop.run_until_complete(async_load_models(models, limit))
180
- except Exception as e:
181
- print(e)
182
- pass
183
- finally:
184
- loop.close()
185
 
186
 
187
  positive_prefix = {
@@ -295,56 +280,73 @@ def change_model(model_name: str):
295
  return get_model_info_md(model_name)
296
 
297
 
298
- def infer(prompt: str, neg_prompt: str, model_name: str):
299
- from PIL import Image
 
 
 
 
 
 
 
 
 
300
  import random
301
- seed = ""
302
  rand = random.randint(1, 500)
303
  for i in range(rand):
304
- seed += " "
305
- caption = model_name.split("/")[-1]
 
 
 
306
  try:
307
- model = load_model(model_name)
308
- if not model: return (Image.Image(), None)
309
- image_path = model(prompt + seed)
310
- image = Image.open(image_path).convert('RGBA')
311
- except Exception as e:
312
  print(e)
313
- return (Image.Image(), None)
314
- return (image, caption)
315
-
316
-
317
- async def infer_multi(prompt: str, neg_prompt: str, results: list, image_num: float, model_name: str,
318
- pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
319
- from tqdm.asyncio import tqdm_asyncio
320
- image_num = int(image_num)
321
- images = results if results else []
322
- prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
323
- tasks = [asyncio.to_thread(infer, prompt, neg_prompt, model_name) for i in range(image_num)]
324
- #results = await asyncio.gather(*tasks, return_exceptions=True)
325
- results = await tqdm_asyncio.gather(*tasks)
326
- if not results: results = []
327
- for result in results:
328
  with lock:
329
- if result and result[1]: images.append(result)
330
- yield images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
 
333
- async def infer_multi_random(prompt: str, neg_prompt: str, results: list, image_num: float,
334
- pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = [], progress=gr.Progress(track_tqdm=True)):
335
- from tqdm.asyncio import tqdm_asyncio
336
  import random
337
- image_num = int(image_num)
338
- images = results if results else []
339
  random.seed()
340
- model_names = random.choices(list(loaded_models.keys()), k = image_num)
341
- prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
342
- tasks = [asyncio.to_thread(infer, prompt, neg_prompt, model_name) for model_name in model_names]
343
- #results = await asyncio.gather(*tasks, return_exceptions=True)
344
- results = await tqdm_asyncio.gather(*tasks)
345
- if not results: results = []
346
- for result in results:
347
- with lock:
348
- if result and result[1]: images.append(result)
349
- yield images
350
-
 
 
80
  return info
81
 
82
 
83
+ def rename_image(image_path: str | None, model_name: str):
84
+ from PIL import Image
85
  from datetime import datetime, timezone, timedelta
86
+ if image_path is None: return None
87
  dt_now = datetime.now(timezone(timedelta(hours=9)))
88
+ filename = f"{model_name.split('/')[-1]}_{dt_now.strftime('%Y%m%d_%H%M%S')}.png"
89
+ try:
90
+ if Path(image_path).exists():
91
+ png_path = "image.png"
92
+ Image.open(image_path).convert('RGBA').save(png_path, "PNG")
93
+ new_path = str(Path(png_path).resolve().rename(Path(filename).resolve()))
94
+ return new_path
95
+ else:
96
+ return None
97
+ except Exception as e:
98
+ print(e)
99
+ return None
100
+
101
+
102
+ def save_gallery(image_path: str | None, images: list[tuple] | None):
103
+ if images is None: images = []
104
+ files = [i[0] for i in images]
105
+ if image_path is None: return images, files
106
+ files.insert(0, str(image_path))
107
+ images.insert(0, (str(image_path), Path(image_path).stem))
108
+ return images, files
109
 
110
 
111
+ # https://github.com/gradio-app/gradio/blob/main/gradio/external.py
112
+ # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
113
  def load_from_model(model_name: str, hf_token: str = None):
114
  import httpx
115
  import huggingface_hub
 
125
  f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
126
  )
127
  headers["X-Wait-For-Model"] = "true"
128
+ client = huggingface_hub.InferenceClient(model=model_name, headers=headers, token=hf_token, timeout=600)
129
  inputs = gr.components.Textbox(label="Input")
130
  outputs = gr.components.Image(label="Output")
131
  fn = client.text_to_image
 
164
  return loaded_models[model_name]
165
 
166
 
167
+ def load_models(models: list):
168
+ for model in models:
169
+ load_model(model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
 
172
  positive_prefix = {
 
280
  return get_model_info_md(model_name)
281
 
282
 
283
+ def warm_model(model_name: str):
284
+ model = load_model(model_name)
285
+ if model:
286
+ try:
287
+ print(f"Warming model: {model_name}")
288
+ model(" ")
289
+ except Exception as e:
290
+ print(e)
291
+
292
+
293
+ async def infer(model_name: str, prompt: str, neg_prompt: str, timeout: float):
294
  import random
295
+ noise = ""
296
  rand = random.randint(1, 500)
297
  for i in range(rand):
298
+ noise += " "
299
+ model = load_model(model_name)
300
+ if not model: return None
301
+ task = asyncio.create_task(asyncio.to_thread(model, f'{prompt} {noise}'))
302
+ await asyncio.sleep(0)
303
  try:
304
+ result = await asyncio.wait_for(task, timeout=timeout)
305
+ except (Exception, asyncio.TimeoutError) as e:
 
 
 
306
  print(e)
307
+ print(f"Task timed out: {model_name}")
308
+ if not task.done(): task.cancel()
309
+ result = None
310
+ if task.done() and result is not None:
 
 
 
 
 
 
 
 
 
 
 
311
  with lock:
312
+ image = rename_image(result, model_name)
313
+ return image
314
+ return None
315
+
316
+
317
+ infer_timeout = 300
318
+ def infer_fn(model_name: str, prompt: str, neg_prompt: str,
319
+ pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
320
+ if model_name == 'NA':
321
+ return None
322
+ try:
323
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
324
+ loop = asyncio.new_event_loop()
325
+ result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, infer_timeout))
326
+ except (Exception, asyncio.CancelledError) as e:
327
+ print(e)
328
+ print(f"Task aborted: {model_name}")
329
+ result = None
330
+ finally:
331
+ loop.close()
332
+ return result
333
 
334
 
335
+ def infer_rand_fn(model_name_dummy: str, prompt: str, neg_prompt: str,
336
+ pos_pre: list = [], pos_suf: list = [], neg_pre: list = [], neg_suf: list = []):
 
337
  import random
338
+ if model_name_dummy == 'NA':
339
+ return None
340
  random.seed()
341
+ model_name = random.choice(list(loaded_models.keys()))
342
+ try:
343
+ prompt, neg_prompt = recom_prompt(prompt, neg_prompt, pos_pre, pos_suf, neg_pre, neg_suf)
344
+ loop = asyncio.new_event_loop()
345
+ result = loop.run_until_complete(infer(model_name, prompt, neg_prompt, infer_timeout))
346
+ except (Exception, asyncio.CancelledError) as e:
347
+ print(e)
348
+ print(f"Task aborted: {model_name}")
349
+ result = None
350
+ finally:
351
+ loop.close()
352
+ return result