mantrakp committed on
Commit
42ae52a
•
1 Parent(s): 578deb0

Add text, audio, and video tabs


This commit adds three new files: text_tab.py, audio_tab.py, and video_tab.py. They contain the initial implementation of the text, audio, and video tabs for the application; for now, each tab consists of a single label that says "Coming soon...". This is the first step towards implementing these tabs.
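The tab modules themselves are not shown in this diff; based on the description above and the imports added to app.py (`from tabs.text_tab import text_tab`, and so on), each one is presumably a small function that builds a Gradio tab holding a placeholder label. A minimal sketch, assuming that structure (the file path and label text are inferred, not taken from the diff):

```python
# tabs/text_tab.py -- hypothetical sketch; the actual file contents are not part of this diff
import gradio as gr

def text_tab():
    # Placeholder tab wired up from main() in app.py; real features are still to come.
    with gr.Tab(label="📄 Text"):
        gr.Label("Coming soon...")
```

The audio and video modules would presumably follow the same pattern with their own labels.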

app.py CHANGED
@@ -1,816 +1,32 @@
1
- # Testing one file gradio app for zero gpu spaces not working as expected.
2
- # Check here for the issue: https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/106#66e278a396acd45223e0d00b
3
-
4
- import os
5
- import gc
6
- import json
7
- import random
8
- from typing import List, Optional
9
-
10
- import spaces
11
  import gradio as gr
12
- from huggingface_hub import ModelCard
13
- import torch
14
- from pydantic import BaseModel
15
- from PIL import Image
16
- from diffusers import (
17
- AutoPipelineForText2Image,
18
- AutoPipelineForImage2Image,
19
- AutoPipelineForInpainting,
20
- DiffusionPipeline,
21
- AutoencoderKL,
22
- FluxControlNetModel,
23
- FluxMultiControlNetModel,
24
- )
25
- from huggingface_hub import hf_hub_download
26
- from diffusers.schedulers import *
27
- from huggingface_hub import hf_hub_download
28
- from controlnet_aux.processor import Processor
29
- from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
30
-
31
-
32
- # Initialize System
33
- os.system("pip install --upgrade pip")
34
-
35
-
36
- def load_sd():
37
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
- device = "cuda" if torch.cuda.is_available() else "cpu"
39
-
40
- # Models
41
- models = [
42
- {
43
- "repo_id": "black-forest-labs/FLUX.1-dev",
44
- "loader": "flux",
45
- "compute_type": torch.bfloat16,
46
- }
47
- ]
48
-
49
- for model in models:
50
- try:
51
- model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
52
- model['repo_id'],
53
- vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device),
54
- torch_dtype = model['compute_type'],
55
- safety_checker = None,
56
- variant = "fp16"
57
- ).to(device)
58
- except:
59
- model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
60
- model['repo_id'],
61
- vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
62
- torch_dtype = model['compute_type'],
63
- safety_checker = None
64
- ).to(device)
65
-
66
- model["pipeline"].enable_model_cpu_offload()
67
-
68
-
69
- # VAE n Refiner
70
- flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
71
- sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
72
- refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
73
- refiner.enable_model_cpu_offload()
74
-
75
-
76
- # ControlNet
77
- controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
78
- "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
79
- torch_dtype=torch.bfloat16
80
- ).to(device)])
81
-
82
- return device, models, flux_vae, sdxl_vae, refiner, controlnet
83
-
84
-
85
- device, models, flux_vae, sdxl_vae, refiner, controlnet = load_sd()
86
-
87
-
88
- # Models
89
- class ControlNetReq(BaseModel):
90
- controlnets: List[str] # ["canny", "tile", "depth"]
91
- control_images: List[Image.Image]
92
- controlnet_conditioning_scale: List[float]
93
-
94
- class Config:
95
- arbitrary_types_allowed=True
96
-
97
-
98
- class FluxReq(BaseModel):
99
- model: str = ""
100
- prompt: str = ""
101
- fast_generation: Optional[bool] = True
102
- loras: Optional[list] = []
103
- resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
104
- scheduler: Optional[str] = "euler_fl"
105
- height: int = 1024
106
- width: int = 1024
107
- num_images_per_prompt: int = 1
108
- num_inference_steps: int = 8
109
- guidance_scale: float = 3.5
110
- seed: Optional[int] = 0
111
- refiner: bool = False
112
- vae: bool = True
113
- controlnet_config: Optional[ControlNetReq] = None
114
-
115
- class Config:
116
- arbitrary_types_allowed=True
117
-
118
-
119
- class FluxImg2ImgReq(FluxReq):
120
- image: Image.Image
121
- strength: float = 1.0
122
-
123
- class Config:
124
- arbitrary_types_allowed=True
125
-
126
-
127
- class FluxInpaintReq(FluxImg2ImgReq):
128
- mask_image: Image.Image
129
-
130
- class Config:
131
- arbitrary_types_allowed=True
132
-
133
-
134
- # Helper Functions
135
- def get_control_mode(controlnet_config: ControlNetReq):
136
- control_mode = []
137
- layers = ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"]
138
-
139
- for c in controlnet_config.controlnets:
140
- if c in layers:
141
- control_mode.append(layers.index(c))
142
-
143
- return control_mode
144
-
145
-
146
- def get_pipe(request: FluxReq | FluxImg2ImgReq | FluxInpaintReq):
147
- for m in models:
148
- if m['repo_id'] == request.model:
149
- pipe_args = {
150
- "pipeline": m['pipeline'],
151
- }
152
-
153
-
154
- # Set ControlNet config
155
- if request.controlnet_config:
156
- pipe_args["control_mode"] = get_control_mode(request.controlnet_config)
157
- pipe_args["controlnet"] = [controlnet]
158
-
159
-
160
- # Choose Pipeline Mode
161
- if isinstance(request, FluxReq):
162
- pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
163
- elif isinstance(request, FluxImg2ImgReq):
164
- pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
165
- elif isinstance(request, FluxInpaintReq):
166
- pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
167
-
168
-
169
- # Enable or Disable Refiner
170
- if request.vae:
171
- pipe_args["pipeline"].vae = flux_vae
172
- elif not request.vae:
173
- pipe_args["pipeline"].vae = None
174
-
175
-
176
- # Set Scheduler
177
- pipe_args["pipeline"].scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe_args["pipeline"].scheduler.config)
178
-
179
-
180
- # Set Loras
181
- if request.loras:
182
- for i, lora in enumerate(request.loras):
183
- pipe_args["pipeline"].load_lora_weights(request.lora['repo_id'], adapter_name=f"lora_{i}")
184
- adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
185
- adapter_weights = [lora['weight'] for lora in request.loras]
186
-
187
- if request.fast_generation:
188
- hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
189
- hyper_weight = 0.125
190
- pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
191
- adapter_names.append("hyper_lora")
192
- adapter_weights.append(hyper_weight)
193
-
194
- pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)
195
-
196
- return pipe_args
197
-
198
-
199
- def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
200
- for image in images:
201
- if resize_mode == "resize_only":
202
- image = image.resize((width, height))
203
- elif resize_mode == "crop_and_resize":
204
- image = image.crop((0, 0, width, height))
205
- elif resize_mode == "resize_and_fill":
206
- image = image.resize((width, height), Image.Resampling.LANCZOS)
207
-
208
- return images
209
-
210
-
211
- def get_controlnet_images(controlnet_config: ControlNetReq, height: int, width: int, resize_mode: str):
212
- response_images = []
213
- control_images = resize_images(controlnet_config.control_images, height, width, resize_mode)
214
- for controlnet, image in zip(controlnet_config.controlnets, control_images):
215
- if controlnet == "canny":
216
- processor = Processor('canny')
217
- elif controlnet == "depth":
218
- processor = Processor('depth_midas')
219
- elif controlnet == "pose":
220
- processor = Processor('openpose_full')
221
- else:
222
- raise ValueError(f"Invalid Controlnet: {controlnet}")
223
-
224
- response_images.append(processor(image, to_pil=True))
225
-
226
- return response_images
227
-
228
-
229
- def get_prompt_attention(pipeline, prompt):
230
- return get_weighted_text_embeddings_flux1(pipeline, prompt)
231
-
232
-
233
- def cleanup(pipeline, loras = None):
234
- if loras:
235
- pipeline.unload_lora_weights()
236
- gc.collect()
237
- torch.cuda.empty_cache()
238
-
239
 
240
- # Gen Function
241
- def gen_img(request: FluxReq | FluxImg2ImgReq | FluxInpaintReq):
242
- pipe_args = get_pipe(request)
243
- pipeline = pipe_args["pipeline"]
244
- try:
245
- positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)
246
-
247
- # Common Args
248
- args = {
249
- 'prompt_embeds': positive_prompt_embeds,
250
- 'pooled_prompt_embeds': positive_prompt_pooled,
251
- 'height': request.height,
252
- 'width': request.width,
253
- 'num_images_per_prompt': request.num_images_per_prompt,
254
- 'num_inference_steps': request.num_inference_steps,
255
- 'guidance_scale': request.guidance_scale,
256
- 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
257
- }
258
-
259
- if request.controlnet_config:
260
- args['control_mode'] = get_control_mode(request.controlnet_config)
261
- args['control_images'] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
262
- args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
263
-
264
- if isinstance(request, (FluxImg2ImgReq, FluxInpaintReq)):
265
- args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
266
- args['strength'] = request.strength
267
-
268
- if isinstance(request, FluxInpaintReq):
269
- args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
270
-
271
- # Generate
272
- images = pipeline(**args).images
273
-
274
- # Refiner
275
- if request.refiner:
276
- images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
277
-
278
- cleanup(pipeline, request.loras)
279
-
280
- return images
281
- except Exception as e:
282
- cleanup(pipeline, request.loras)
283
- raise gr.Error(f"Error: {e}")
284
-
285
-
286
-
287
- # CSS
288
- css = """
289
- @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
290
- body {
291
- font-family: 'Poppins', sans-serif !important;
292
- }
293
- .center-content {
294
- text-align: center;
295
- max-width: 600px;
296
- margin: 0 auto;
297
- padding: 20px;
298
- }
299
- .center-content h1 {
300
- font-weight: 600;
301
- margin-bottom: 1rem;
302
- }
303
- .center-content p {
304
- margin-bottom: 1.5rem;
305
- }
306
- """
307
-
308
-
309
- flux_models = ["black-forest-labs/FLUX.1-dev"]
310
- with open("data/images/loras/flux.json", "r") as f:
311
- loras = json.load(f)
312
-
313
-
314
- # Event functions
315
- def update_fast_generation(model, fast_generation):
316
- if fast_generation:
317
- return (
318
- gr.update(
319
- value=3.5
320
- ),
321
- gr.update(
322
- value=8
323
- )
324
- )
325
-
326
-
327
- def selected_lora_from_gallery(evt: gr.SelectData):
328
- return (
329
- gr.update(
330
- value=evt.index
331
- )
332
- )
333
-
334
-
335
- def update_selected_lora(custom_lora):
336
- link = custom_lora.split("/")
337
-
338
- if len(link) == 2:
339
- model_card = ModelCard.load(custom_lora)
340
- trigger_word = model_card.data.get("instance_prompt", "")
341
- image_url = f"""https://huggingface.co/{custom_lora}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
342
-
343
- custom_lora_info_css = """
344
- <style>
345
- .custom-lora-info {
346
- font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
347
- background: linear-gradient(135deg, #4a90e2, #7b61ff);
348
- color: white;
349
- padding: 16px;
350
- border-radius: 8px;
351
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
352
- margin: 16px 0;
353
- }
354
- .custom-lora-header {
355
- font-size: 18px;
356
- font-weight: 600;
357
- margin-bottom: 12px;
358
- }
359
- .custom-lora-content {
360
- display: flex;
361
- align-items: center;
362
- background-color: rgba(255, 255, 255, 0.1);
363
- border-radius: 6px;
364
- padding: 12px;
365
- }
366
- .custom-lora-image {
367
- width: 80px;
368
- height: 80px;
369
- object-fit: cover;
370
- border-radius: 6px;
371
- margin-right: 16px;
372
- }
373
- .custom-lora-text h3 {
374
- margin: 0 0 8px 0;
375
- font-size: 16px;
376
- font-weight: 600;
377
- }
378
- .custom-lora-text small {
379
- font-size: 14px;
380
- opacity: 0.9;
381
- }
382
- .custom-trigger-word {
383
- background-color: rgba(255, 255, 255, 0.2);
384
- padding: 2px 6px;
385
- border-radius: 4px;
386
- font-weight: 600;
387
- }
388
- </style>
389
- """
390
-
391
- custom_lora_info_html = f"""
392
- <div class="custom-lora-info">
393
- <div class="custom-lora-header">Custom LoRA: {custom_lora}</div>
394
- <div class="custom-lora-content">
395
- <img class="custom-lora-image" src="{image_url}" alt="LoRA preview">
396
- <div class="custom-lora-text">
397
- <h3>{link[1].replace("-", " ").replace("_", " ")}</h3>
398
- <small>{"Using: <span class='custom-trigger-word'>"+trigger_word+"</span> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}</small>
399
- </div>
400
- </div>
401
- </div>
402
- """
403
-
404
- custom_lora_info_html = f"{custom_lora_info_css}{custom_lora_info_html}"
405
-
406
- return (
407
- gr.update( # selected_lora
408
- value=custom_lora,
409
- ),
410
- gr.update( # custom_lora_info
411
- value=custom_lora_info_html,
412
- visible=True
413
- )
414
- )
415
-
416
- else:
417
- return (
418
- gr.update( # selected_lora
419
- value=custom_lora,
420
- ),
421
- gr.update( # custom_lora_info
422
- value=custom_lora_info_html if len(link) == 0 else "",
423
- visible=False
424
- )
425
- )
426
-
427
-
428
- def add_to_enabled_loras(model, selected_lora, enabled_loras):
429
- lora_data = loras
430
- try:
431
- selected_lora = int(selected_lora)
432
-
433
- if 0 <= selected_lora: # is the index of the lora in the gallery
434
- lora_info = lora_data[selected_lora]
435
- enabled_loras.append({
436
- "repo_id": lora_info["repo"],
437
- "trigger_word": lora_info["trigger_word"]
438
- })
439
- except ValueError:
440
- link = selected_lora.split("/")
441
- if len(link) == 2:
442
- model_card = ModelCard.load(selected_lora)
443
- trigger_word = model_card.data.get("instance_prompt", "")
444
- enabled_loras.append({
445
- "repo_id": selected_lora,
446
- "trigger_word": trigger_word
447
- })
448
-
449
- return (
450
- gr.update( # selected_lora
451
- value=""
452
- ),
453
- gr.update( # custom_lora_info
454
- value="",
455
- visible=False
456
- ),
457
- gr.update( # enabled_loras
458
- value=enabled_loras
459
- )
460
  )
461
 
462
-
463
- def update_lora_sliders(enabled_loras):
464
- sliders = []
465
- remove_buttons = []
466
-
467
- for lora in enabled_loras:
468
- sliders.append(
469
- gr.update(
470
- label=lora.get("repo_id", ""),
471
- info=f"Trigger Word: {lora.get('trigger_word', '')}",
472
- visible=True,
473
- interactive=True
474
- )
475
- )
476
- remove_buttons.append(
477
- gr.update(
478
- visible=True,
479
- interactive=True
480
- )
481
- )
482
-
483
- if len(sliders) < 6:
484
- for i in range(len(sliders), 6):
485
- sliders.append(
486
- gr.update(
487
- visible=False
488
- )
489
- )
490
- remove_buttons.append(
491
- gr.update(
492
- visible=False
493
- )
494
- )
495
-
496
- return *sliders, *remove_buttons
497
-
498
-
499
- def remove_from_enabled_loras(enabled_loras, index):
500
- enabled_loras.pop(index)
501
- return (
502
- gr.update(
503
- value=enabled_loras
504
- )
505
- )
506
-
507
-
508
- @spaces.GPU
509
- def generate_image(
510
- model, prompt, fast_generation, enabled_loras,
511
- lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5,
512
- img2img_image, inpaint_image, canny_image, pose_image, depth_image,
513
- img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength,
514
- resize_mode,
515
- scheduler, image_height, image_width, image_num_images_per_prompt,
516
- image_num_inference_steps, image_guidance_scale, image_seed,
517
- refiner, vae
518
- ):
519
- base_args = {
520
- "model": model,
521
- "prompt": prompt,
522
- "fast_generation": fast_generation,
523
- "loras": None,
524
- "resize_mode": resize_mode,
525
- "scheduler": scheduler,
526
- "height": int(image_height),
527
- "width": int(image_width),
528
- "num_images_per_prompt": float(image_num_images_per_prompt),
529
- "num_inference_steps": float(image_num_inference_steps),
530
- "guidance_scale": float(image_guidance_scale),
531
- "seed": int(image_seed),
532
- "refiner": refiner,
533
- "vae": vae,
534
- "controlnet_config": None,
535
- }
536
- base_args = FluxReq(**base_args)
537
-
538
- if len(enabled_loras) > 0:
539
- base_args.loras = []
540
- for enabled_lora, slider in zip(enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5]):
541
- if enabled_lora['repo_id']:
542
- base_args.loras.append({
543
- "repo_id": enabled_lora['repo_id'],
544
- "weight": slider
545
- })
546
-
547
- image = None
548
- mask_image = None
549
- strength = None
550
-
551
- if img2img_image:
552
- image = img2img_image
553
- strength = float(img2img_strength)
554
-
555
- base_args = FluxImg2ImgReq(
556
- **base_args.__dict__,
557
- image=image,
558
- strength=strength
559
- )
560
- elif inpaint_image:
561
- image = inpaint_image['background'] if not all(pixel == (0, 0, 0) for pixel in list(inpaint_image['background'].getdata())) else None
562
- mask_image = inpaint_image['layers'][0] if image else None
563
- strength = float(inpaint_strength)
564
-
565
- if image and mask_image:
566
- base_args = FluxInpaintReq(
567
- **base_args.__dict__,
568
- image=image,
569
- mask_image=mask_image,
570
- strength=strength
571
- )
572
- elif any([canny_image, pose_image, depth_image]):
573
- base_args.controlnet_config = ControlNetReq(
574
- controlnets=[],
575
- control_images=[],
576
- controlnet_conditioning_scale=[]
577
- )
578
-
579
- if canny_image:
580
- base_args.controlnet_config.controlnets.append("canny")
581
- base_args.controlnet_config.control_images.append(canny_image)
582
- base_args.controlnet_config.controlnet_conditioning_scale.append(float(canny_strength))
583
- if pose_image:
584
- base_args.controlnet_config.controlnets.append("pose")
585
- base_args.controlnet_config.control_images.append(pose_image)
586
- base_args.controlnet_config.controlnet_conditioning_scale.append(float(pose_strength))
587
- if depth_image:
588
- base_args.controlnet_config.controlnets.append("depth")
589
- base_args.controlnet_config.control_images.append(depth_image)
590
- base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
591
- else:
592
- base_args = FluxReq(**base_args.__dict__)
593
-
594
- return gr.update(
595
- value=gen_img(base_args),
596
- interactive=True
597
- )
598
-
599
-
600
- # Main Gradio app
601
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
602
- # Header
603
- with gr.Column(elem_classes="center-content"):
604
- gr.Markdown("""
605
- # 🚀 AAI: All AI
606
- Unleash your creativity with our multi-modal AI platform.
607
- [![Sync code to HF Space](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml/badge.svg)](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml)
608
- """)
609
-
610
- # Tabs
611
- with gr.Tabs():
612
- with gr.Tab(label="🖼️ Image"):
613
- with gr.Tabs():
614
- with gr.Tab("Flux"):
615
- """
616
- Create the image tab for Generative Image Generation Models
617
-
618
- Args:
619
- models: list
620
- A list containing the models repository paths
621
- gap_iol, gap_la, gap_le, gap_eio, gap_io: Optional[List[dict]]
622
- A list of dictionaries containing the title and component for the custom gradio component
623
- Example:
624
- def gr_comp():
625
- gr.Label("Hello World")
626
-
627
- [
628
- {
629
- 'title': "Title",
630
- 'component': gr_comp()
631
- }
632
- ]
633
- loras: list
634
- A list of dictionaries containing the image and title for the Loras Gallery
635
- Generally a loaded json file from the data folder
636
-
637
- """
638
- def process_gaps(gaps: List[dict]):
639
- for gap in gaps:
640
- with gr.Accordion(gap['title']):
641
- gap['component']
642
-
643
-
644
- with gr.Row():
645
- with gr.Column():
646
- with gr.Group() as image_options:
647
- model = gr.Dropdown(label="Models", choices=flux_models, value=flux_models[0], interactive=True)
648
- prompt = gr.Textbox(lines=5, label="Prompt")
649
- fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) 🧪")
650
-
651
-
652
- with gr.Accordion("Loras", open=True): # Lora Gallery
653
- lora_gallery = gr.Gallery(
654
- label="Gallery",
655
- value=[(lora['image'], lora['title']) for lora in loras],
656
- allow_preview=False,
657
- columns=3,
658
- rows=3,
659
- type="pil"
660
- )
661
-
662
- with gr.Group():
663
- with gr.Column():
664
- with gr.Row():
665
- custom_lora = gr.Textbox(label="Custom Lora", info="Enter a Huggingface repo path")
666
- selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")
667
-
668
- custom_lora_info = gr.HTML(visible=False)
669
- add_lora = gr.Button(value="Add LoRA")
670
-
671
- enabled_loras = gr.State(value=[])
672
- with gr.Group():
673
- with gr.Row():
674
- for i in range(6): # only support max 6 loras due to inference time
675
- with gr.Column():
676
- with gr.Column(scale=2):
677
- globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
678
- with gr.Column():
679
- globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)
680
-
681
-
682
- with gr.Accordion("Embeddings", open=False): # Embeddings
683
- gr.Label("To be implemented")
684
-
685
-
686
- with gr.Accordion("Image Options"): # Image Options
687
- with gr.Tabs():
688
- image_options = {
689
- "img2img": "Upload Image",
690
- "inpaint": "Upload Image",
691
- "canny": "Upload Image",
692
- "pose": "Upload Image",
693
- "depth": "Upload Image",
694
- }
695
-
696
- for image_option, label in image_options.items():
697
- with gr.Tab(image_option):
698
- if not image_option in ['inpaint', 'scribble']:
699
- globals()[f"{image_option}_image"] = gr.Image(label=label, type="pil")
700
- elif image_option in ['inpaint', 'scribble']:
701
- globals()[f"{image_option}_image"] = gr.ImageEditor(
702
- label=label,
703
- image_mode='RGB',
704
- layers=False,
705
- brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
706
- interactive=True,
707
- type="pil",
708
- )
709
-
710
- # Image Strength (Co-relates to controlnet strength, strength for img2img n inpaint)
711
- globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)
712
-
713
- resize_mode = gr.Radio(
714
- label="Resize Mode",
715
- choices=["crop and resize", "resize only", "resize and fill"],
716
- value="resize and fill",
717
- interactive=True
718
- )
719
-
720
-
721
- with gr.Column():
722
- with gr.Group():
723
- output_images = gr.Gallery(
724
- label="Output Images",
725
- value=[],
726
- allow_preview=True,
727
- type="pil",
728
- interactive=False,
729
- )
730
- generate_images = gr.Button(value="Generate Images", variant="primary")
731
-
732
- with gr.Accordion("Advance Settings", open=True):
733
- with gr.Row():
734
- scheduler = gr.Dropdown(
735
- label="Scheduler",
736
- choices = [
737
- "fm_euler"
738
- ],
739
- value="fm_euler",
740
- interactive=True
741
- )
742
-
743
- with gr.Row():
744
- for column in range(2):
745
- with gr.Column():
746
- options = [
747
- ("Height", "image_height", 64, 1024, 64, 1024, True),
748
- ("Width", "image_width", 64, 1024, 64, 1024, True),
749
- ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
750
- ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
751
- ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
752
- ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 3.5, True),
753
- ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
754
- ]
755
- for label, var_name, min_val, max_val, step, value, visible in options[column::2]:
756
- globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)
757
-
758
- with gr.Row():
759
- refiner = gr.Checkbox(
760
- label="Refiner 🧪",
761
- value=False,
762
- )
763
- vae = gr.Checkbox(
764
- label="VAE",
765
- value=True,
766
- )
767
-
768
-
769
- # Events
770
- # Base Options
771
- fast_generation.change(update_fast_generation, [model, fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
772
-
773
-
774
- # Lora Gallery
775
- lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
776
- custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
777
- add_lora.click(add_to_enabled_loras, [model, selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
778
- enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
779
-
780
- for i in range(6):
781
- globals()[f"lora_remove_{i}"].click(
782
- lambda enabled_loras, index=i: remove_from_enabled_loras(enabled_loras, index),
783
- [enabled_loras],
784
- [enabled_loras]
785
- )
786
-
787
-
788
- # Generate Image
789
- generate_images.click(
790
- generate_image, # type: ignore
791
- [
792
- model, prompt, fast_generation, enabled_loras,
793
- lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
794
- img2img_image, inpaint_image, canny_image, pose_image, depth_image, # type: ignore
795
- img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, # type: ignore
796
- resize_mode,
797
- scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
798
- image_num_inference_steps, image_guidance_scale, image_seed, # type: ignore
799
- refiner, vae
800
- ],
801
- [output_images]
802
- )
803
- with gr.Tab("SDXL"):
804
- gr.Label("To be implemented")
805
- with gr.Tab(label="🎵 Audio"):
806
- gr.Label("Coming soon!")
807
- with gr.Tab(label="🎬 Video"):
808
- gr.Label("Coming soon!")
809
- with gr.Tab(label="📄 Text"):
810
- gr.Label("Coming soon!")
811
-
812
-
813
- demo.launch(
814
- share=False,
815
- debug=True,
816
- )
1
  import gradio as gr
2
 
3
+ from config import css
4
+ from tabs.image_tab import image_tab
5
+ from tabs.audio_tab import audio_tab
6
+ from tabs.video_tab import video_tab
7
+ from tabs.text_tab import text_tab
8
+
9
+ def main():
10
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
11
+ # Header
12
+ with gr.Column(elem_classes="center-content"):
13
+ gr.Markdown("""
14
+ # 🚀 AAI: All AI
15
+ Unleash your creativity with our multi-modal AI platform.
16
+ [![Sync code to HF Space](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml/badge.svg)](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml)
17
+ """)
18
+
19
+ # Tabs
20
+ with gr.Tabs():
21
+ image_tab()
22
+ audio_tab()
23
+ video_tab()
24
+ text_tab()
25
+
26
+ demo.launch(
27
+ share=False,
28
+ debug=True,
29
  )
30
 
31
+ if __name__ == "__main__":
32
+ main()
config.py ADDED
@@ -0,0 +1,36 @@
1
+ # config.py
2
+
3
+ import json
4
+
5
+ css = """
6
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
7
+ body {
8
+ font-family: 'Poppins', sans-serif !important;
9
+ }
10
+ .center-content {
11
+ text-align: center;
12
+ max-width: 600px;
13
+ margin: 0 auto;
14
+ padding: 20px;
15
+ }
16
+ .center-content h1 {
17
+ font-weight: 600;
18
+ margin-bottom: 1rem;
19
+ }
20
+ .center-content p {
21
+ margin-bottom: 1.5rem;
22
+ }
23
+ """
24
+
25
+
26
+ # Models
27
+ flux_models = ["black-forest-labs/FLUX.1-dev"]
28
+ sdxl_models = ["stabilityai/stable-diffusion-xl-base-1.0"]
29
+
30
+
31
+ # Load LoRAs
32
+ with open("data/loras/flux.json", "r") as f:
33
+ flux_loras = json.load(f)
34
+
35
+ with open("data/loras/sdxl.json", "r") as f:
36
+ sdxl_loras = json.load(f)
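config.py only loads these JSON files; the rest of the commit indexes each entry by `repo`, `trigger_word`, `image`, and `title` (see the LoRA gallery setup and `add_to_enabled_loras` in the Flux code below). A plausible entry shape, inferred from those accesses rather than from the data files themselves:

```python
# Hypothetical shape of one entry in data/loras/flux.json, inferred from how the code reads it.
example_flux_lora_entry = {
    "repo": "some-user/some-flux-lora",  # Hugging Face repo id passed to load_lora_weights
    "trigger_word": "myTriggerWord",     # shown next to the LoRA weight slider
    "image": "preview.png",              # thumbnail displayed in the LoRA gallery
    "title": "Example LoRA",             # caption displayed in the gallery
}
```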
data/{images/loras → loras}/flux.json RENAMED
File without changes
data/{images/loras → loras}/sdxl.json RENAMED
File without changes
modules/events/flux_events.py ADDED
@@ -0,0 +1,298 @@
1
+ import json
2
+ from typing import List
3
+
4
+ import spaces
5
+ import gradio as gr
6
+ from huggingface_hub import ModelCard
7
+
8
+ from modules.helpers.flux_helpers import FluxReq, FluxImg2ImgReq, FluxInpaintReq, ControlNetReq, gen_img
9
+ from config import flux_models, flux_loras
10
+
11
+ flux_models = ["black-forest-labs/FLUX.1-dev"]
12
+ loras = flux_loras
13
+
14
+
15
+ # Event functions
16
+ def update_fast_generation(model, fast_generation):
17
+ if fast_generation:
18
+ return (
19
+ gr.update(
20
+ value=3.5
21
+ ),
22
+ gr.update(
23
+ value=8
24
+ )
25
+ )
26
+
27
+
28
+ def selected_lora_from_gallery(evt: gr.SelectData):
29
+ return (
30
+ gr.update(
31
+ value=evt.index
32
+ )
33
+ )
34
+
35
+
36
+ def update_selected_lora(custom_lora):
37
+ link = custom_lora.split("/")
38
+
39
+ if len(link) == 2:
40
+ model_card = ModelCard.load(custom_lora)
41
+ trigger_word = model_card.data.get("instance_prompt", "")
42
+ image_url = f"""https://huggingface.co/{custom_lora}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
43
+
44
+ custom_lora_info_css = """
45
+ <style>
46
+ .custom-lora-info {
47
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
48
+ background: linear-gradient(135deg, #4a90e2, #7b61ff);
49
+ color: white;
50
+ padding: 16px;
51
+ border-radius: 8px;
52
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
53
+ margin: 16px 0;
54
+ }
55
+ .custom-lora-header {
56
+ font-size: 18px;
57
+ font-weight: 600;
58
+ margin-bottom: 12px;
59
+ }
60
+ .custom-lora-content {
61
+ display: flex;
62
+ align-items: center;
63
+ background-color: rgba(255, 255, 255, 0.1);
64
+ border-radius: 6px;
65
+ padding: 12px;
66
+ }
67
+ .custom-lora-image {
68
+ width: 80px;
69
+ height: 80px;
70
+ object-fit: cover;
71
+ border-radius: 6px;
72
+ margin-right: 16px;
73
+ }
74
+ .custom-lora-text h3 {
75
+ margin: 0 0 8px 0;
76
+ font-size: 16px;
77
+ font-weight: 600;
78
+ }
79
+ .custom-lora-text small {
80
+ font-size: 14px;
81
+ opacity: 0.9;
82
+ }
83
+ .custom-trigger-word {
84
+ background-color: rgba(255, 255, 255, 0.2);
85
+ padding: 2px 6px;
86
+ border-radius: 4px;
87
+ font-weight: 600;
88
+ }
89
+ </style>
90
+ """
91
+
92
+ custom_lora_info_html = f"""
93
+ <div class="custom-lora-info">
94
+ <div class="custom-lora-header">Custom LoRA: {custom_lora}</div>
95
+ <div class="custom-lora-content">
96
+ <img class="custom-lora-image" src="{image_url}" alt="LoRA preview">
97
+ <div class="custom-lora-text">
98
+ <h3>{link[1].replace("-", " ").replace("_", " ")}</h3>
99
+ <small>{"Using: <span class='custom-trigger-word'>"+trigger_word+"</span> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}</small>
100
+ </div>
101
+ </div>
102
+ </div>
103
+ """
104
+
105
+ custom_lora_info_html = f"{custom_lora_info_css}{custom_lora_info_html}"
106
+
107
+ return (
108
+ gr.update( # selected_lora
109
+ value=custom_lora,
110
+ ),
111
+ gr.update( # custom_lora_info
112
+ value=custom_lora_info_html,
113
+ visible=True
114
+ )
115
+ )
116
+
117
+ else:
118
+ return (
119
+ gr.update( # selected_lora
120
+ value=custom_lora,
121
+ ),
122
+ gr.update( # custom_lora_info
123
+ value=custom_lora_info_html if len(link) == 0 else "",
124
+ visible=False
125
+ )
126
+ )
127
+
128
+
129
+ def add_to_enabled_loras(model, selected_lora, enabled_loras):
130
+ lora_data = loras
131
+ try:
132
+ selected_lora = int(selected_lora)
133
+
134
+ if 0 <= selected_lora: # is the index of the lora in the gallery
135
+ lora_info = lora_data[selected_lora]
136
+ enabled_loras.append({
137
+ "repo_id": lora_info["repo"],
138
+ "trigger_word": lora_info["trigger_word"]
139
+ })
140
+ except ValueError:
141
+ link = selected_lora.split("/")
142
+ if len(link) == 2:
143
+ model_card = ModelCard.load(selected_lora)
144
+ trigger_word = model_card.data.get("instance_prompt", "")
145
+ enabled_loras.append({
146
+ "repo_id": selected_lora,
147
+ "trigger_word": trigger_word
148
+ })
149
+
150
+ return (
151
+ gr.update( # selected_lora
152
+ value=""
153
+ ),
154
+ gr.update( # custom_lora_info
155
+ value="",
156
+ visible=False
157
+ ),
158
+ gr.update( # enabled_loras
159
+ value=enabled_loras
160
+ )
161
+ )
162
+
163
+
164
+ def update_lora_sliders(enabled_loras):
165
+ sliders = []
166
+ remove_buttons = []
167
+
168
+ for lora in enabled_loras:
169
+ sliders.append(
170
+ gr.update(
171
+ label=lora.get("repo_id", ""),
172
+ info=f"Trigger Word: {lora.get('trigger_word', '')}",
173
+ visible=True,
174
+ interactive=True
175
+ )
176
+ )
177
+ remove_buttons.append(
178
+ gr.update(
179
+ visible=True,
180
+ interactive=True
181
+ )
182
+ )
183
+
184
+ if len(sliders) < 6:
185
+ for i in range(len(sliders), 6):
186
+ sliders.append(
187
+ gr.update(
188
+ visible=False
189
+ )
190
+ )
191
+ remove_buttons.append(
192
+ gr.update(
193
+ visible=False
194
+ )
195
+ )
196
+
197
+ return *sliders, *remove_buttons
198
+
199
+
200
+ def remove_from_enabled_loras(enabled_loras, index):
201
+ enabled_loras.pop(index)
202
+ return (
203
+ gr.update(
204
+ value=enabled_loras
205
+ )
206
+ )
207
+
208
+
209
+ @spaces.GPU
210
+ def generate_image(
211
+ model, prompt, fast_generation, enabled_loras,
212
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5,
213
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image,
214
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength,
215
+ resize_mode,
216
+ scheduler, image_height, image_width, image_num_images_per_prompt,
217
+ image_num_inference_steps, image_guidance_scale, image_seed,
218
+ refiner, vae
219
+ ):
220
+ base_args = {
221
+ "model": model,
222
+ "prompt": prompt,
223
+ "fast_generation": fast_generation,
224
+ "loras": None,
225
+ "resize_mode": resize_mode,
226
+ "scheduler": scheduler,
227
+ "height": int(image_height),
228
+ "width": int(image_width),
229
+ "num_images_per_prompt": float(image_num_images_per_prompt),
230
+ "num_inference_steps": float(image_num_inference_steps),
231
+ "guidance_scale": float(image_guidance_scale),
232
+ "seed": int(image_seed),
233
+ "refiner": refiner,
234
+ "vae": vae,
235
+ "controlnet_config": None,
236
+ }
237
+ base_args = FluxReq(**base_args)
238
+
239
+ if len(enabled_loras) > 0:
240
+ base_args.loras = []
241
+ for enabled_lora, slider in zip(enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5]):
242
+ if enabled_lora['repo_id']:
243
+ base_args.loras.append({
244
+ "repo_id": enabled_lora['repo_id'],
245
+ "weight": slider
246
+ })
247
+
248
+ image = None
249
+ mask_image = None
250
+ strength = None
251
+
252
+ if img2img_image:
253
+ image = img2img_image
254
+ strength = float(img2img_strength)
255
+
256
+ base_args = FluxImg2ImgReq(
257
+ **base_args.__dict__,
258
+ image=image,
259
+ strength=strength
260
+ )
261
+ elif inpaint_image:
262
+ image = inpaint_image['background'] if not all(pixel == (0, 0, 0) for pixel in list(inpaint_image['background'].getdata())) else None
263
+ mask_image = inpaint_image['layers'][0] if image else None
264
+ strength = float(inpaint_strength)
265
+
266
+ if image and mask_image:
267
+ base_args = FluxInpaintReq(
268
+ **base_args.__dict__,
269
+ image=image,
270
+ mask_image=mask_image,
271
+ strength=strength
272
+ )
273
+ elif any([canny_image, pose_image, depth_image]):
274
+ base_args.controlnet_config = ControlNetReq(
275
+ controlnets=[],
276
+ control_images=[],
277
+ controlnet_conditioning_scale=[]
278
+ )
279
+
280
+ if canny_image:
281
+ base_args.controlnet_config.controlnets.append("canny")
282
+ base_args.controlnet_config.control_images.append(canny_image)
283
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(canny_strength))
284
+ if pose_image:
285
+ base_args.controlnet_config.controlnets.append("pose")
286
+ base_args.controlnet_config.control_images.append(pose_image)
287
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(pose_strength))
288
+ if depth_image:
289
+ base_args.controlnet_config.controlnets.append("depth")
290
+ base_args.controlnet_config.control_images.append(depth_image)
291
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
292
+ else:
293
+ base_args = FluxReq(**base_args.__dict__)
294
+
295
+ return gr.update(
296
+ value=gen_img(base_args),
297
+ interactive=True
298
+ )
src/ui/audios.py → modules/events/sdxl_events.py RENAMED
File without changes
modules/helpers/common_helpers.py ADDED
@@ -0,0 +1,84 @@
1
+ import gc
2
+ from typing import List, Optional, Dict, Any
3
+
4
+ import torch
5
+ from pydantic import BaseModel
6
+ from PIL import Image
7
+ from diffusers.schedulers import *
8
+ from controlnet_aux.processor import Processor
9
+
10
+ from .flux_helpers import ControlNetReq
11
+
12
+
13
+ class BaseReq(BaseModel):
14
+ model: str = ""
15
+ prompt: str = ""
16
+ fast_generation: Optional[bool] = True
17
+ loras: Optional[list] = []
18
+ resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
19
+ scheduler: Optional[str] = "euler_fl"
20
+ height: int = 1024
21
+ width: int = 1024
22
+ num_images_per_prompt: int = 1
23
+ num_inference_steps: int = 8
24
+ guidance_scale: float = 3.5
25
+ seed: Optional[int] = 0
26
+ refiner: bool = False
27
+ vae: bool = True
28
+ controlnet_config: Optional[ControlNetReq] = None
29
+ custom_addons: Optional[Dict[Any, Any]] = None
30
+
31
+ class Config:
32
+ arbitrary_types_allowed=True
33
+
34
+
35
+ class BaseImg2ImgReq(BaseReq):
36
+ image: Image.Image
37
+ strength: float = 1.0
38
+
39
+ class Config:
40
+ arbitrary_types_allowed=True
41
+
42
+
43
+ class BaseInpaintReq(BaseImg2ImgReq):
44
+ mask_image: Image.Image
45
+
46
+ class Config:
47
+ arbitrary_types_allowed=True
48
+
49
+
50
+ def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
51
+ for image in images:
52
+ if resize_mode == "resize_only":
53
+ image = image.resize((width, height))
54
+ elif resize_mode == "crop_and_resize":
55
+ image = image.crop((0, 0, width, height))
56
+ elif resize_mode == "resize_and_fill":
57
+ image = image.resize((width, height), Image.Resampling.LANCZOS)
58
+
59
+ return images
60
+
61
+
62
+ def get_controlnet_images(controlnet_config: ControlNetReq, height: int, width: int, resize_mode: str):
63
+ response_images = []
64
+ control_images = resize_images(controlnet_config.control_images, height, width, resize_mode)
65
+ for controlnet, image in zip(controlnet_config.controlnets, control_images):
66
+ if controlnet == "canny":
67
+ processor = Processor('canny')
68
+ elif controlnet == "depth":
69
+ processor = Processor('depth_midas')
70
+ elif controlnet == "pose":
71
+ processor = Processor('openpose_full')
72
+ else:
73
+ raise ValueError(f"Invalid Controlnet: {controlnet}")
74
+
75
+ response_images.append(processor(image, to_pil=True))
76
+
77
+ return response_images
78
+
79
+
80
+ def cleanup(pipeline, loras = None):
81
+ if loras:
82
+ pipeline.unload_lora_weights()
83
+ gc.collect()
84
+ torch.cuda.empty_cache()
modules/helpers/flux_helpers.py ADDED
@@ -0,0 +1,185 @@
1
+ import random
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from diffusers import (
6
+ AutoPipelineForText2Image,
7
+ AutoPipelineForImage2Image,
8
+ AutoPipelineForInpainting,
9
+ DiffusionPipeline,
10
+ AutoencoderKL,
11
+ FluxControlNetModel,
12
+ FluxMultiControlNetModel,
13
+ )
14
+ from huggingface_hub import hf_hub_download
15
+ from diffusers.schedulers import *
16
+ from huggingface_hub import hf_hub_download
17
+ from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
18
+
19
+ from .common_helpers import *
20
+
21
+
22
+ def load_sd():
23
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+
26
+ # Models
27
+ models = [
28
+ {
29
+ "repo_id": "black-forest-labs/FLUX.1-dev",
30
+ "loader": "flux",
31
+ "compute_type": torch.bfloat16,
32
+ }
33
+ ]
34
+
35
+ for model in models:
36
+ try:
37
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
38
+ model['repo_id'],
39
+ vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device),
40
+ torch_dtype = model['compute_type'],
41
+ safety_checker = None,
42
+ variant = "fp16"
43
+ ).to(device)
44
+ except:
45
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
46
+ model['repo_id'],
47
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
48
+ torch_dtype = model['compute_type'],
49
+ safety_checker = None
50
+ ).to(device)
51
+
52
+ model["pipeline"].enable_model_cpu_offload()
53
+
54
+
55
+ # VAE n Refiner
56
+ flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
57
+ sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
58
+ refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
59
+ refiner.enable_model_cpu_offload()
60
+
61
+
62
+ # ControlNet
63
+ controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
64
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
65
+ torch_dtype=torch.bfloat16
66
+ ).to(device)])
67
+
68
+ return device, models, flux_vae, sdxl_vae, refiner, controlnet
69
+
70
+
71
+ device, models, flux_vae, sdxl_vae, refiner, controlnet = load_sd()
72
+
73
+
74
+ def get_control_mode(controlnet_config: ControlNetReq):
75
+ control_mode = []
76
+ layers = ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"]
77
+
78
+ for c in controlnet_config.controlnets:
79
+ if c in layers:
80
+ control_mode.append(layers.index(c))
81
+
82
+ return control_mode
83
+
84
+
85
+ def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
86
+ for m in models:
87
+ if m['repo_id'] == request.model:
88
+ pipe_args = {
89
+ "pipeline": m['pipeline'],
90
+ }
91
+
92
+
93
+ # Set ControlNet config
94
+ if request.controlnet_config:
95
+ pipe_args["control_mode"] = get_control_mode(request.controlnet_config)
96
+ pipe_args["controlnet"] = [controlnet]
97
+
98
+
99
+ # Choose Pipeline Mode
100
+ if isinstance(request, BaseReq):
101
+ pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
102
+ elif isinstance(request, BaseImg2ImgReq):
103
+ pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
104
+ elif isinstance(request, BaseInpaintReq):
105
+ pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
106
+
107
+
108
+ # Enable or Disable Refiner
109
+ if request.vae:
110
+ pipe_args["pipeline"].vae = flux_vae
111
+ elif not request.vae:
112
+ pipe_args["pipeline"].vae = None
113
+
114
+
115
+ # Set Scheduler
116
+ pipe_args["pipeline"].scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe_args["pipeline"].scheduler.config)
117
+
118
+
119
+ # Set Loras
120
+ if request.loras:
121
+ for i, lora in enumerate(request.loras):
122
+ pipe_args["pipeline"].load_lora_weights(request.lora['repo_id'], adapter_name=f"lora_{i}")
123
+ adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
124
+ adapter_weights = [lora['weight'] for lora in request.loras]
125
+
126
+ if request.fast_generation:
127
+ hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
128
+ hyper_weight = 0.125
129
+ pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
130
+ adapter_names.append("hyper_lora")
131
+ adapter_weights.append(hyper_weight)
132
+
133
+ pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)
134
+
135
+ return pipe_args
136
+
137
+
138
+ def get_prompt_attention(pipeline, prompt):
139
+ return get_weighted_text_embeddings_flux1(pipeline, prompt)
140
+
141
+
142
+ # Gen Function
143
+ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
144
+ pipe_args = get_pipe(request)
145
+ pipeline = pipe_args["pipeline"]
146
+ try:
147
+ positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)
148
+
149
+ # Common Args
150
+ args = {
151
+ 'prompt_embeds': positive_prompt_embeds,
152
+ 'pooled_prompt_embeds': positive_prompt_pooled,
153
+ 'height': request.height,
154
+ 'width': request.width,
155
+ 'num_images_per_prompt': request.num_images_per_prompt,
156
+ 'num_inference_steps': request.num_inference_steps,
157
+ 'guidance_scale': request.guidance_scale,
158
+ 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
159
+ }
160
+
161
+ if request.controlnet_config:
162
+ args['control_mode'] = get_control_mode(request.controlnet_config)
163
+ args['control_images'] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
164
+ args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
165
+
166
+ if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
167
+ args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
168
+ args['strength'] = request.strength
169
+
170
+ if isinstance(request, BaseInpaintReq):
171
+ args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
172
+
173
+ # Generate
174
+ images = pipeline(**args).images
175
+
176
+ # Refiner
177
+ if request.refiner:
178
+ images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
179
+
180
+ cleanup(pipeline, request.loras)
181
+
182
+ return images
183
+ except Exception as e:
184
+ cleanup(pipeline, request.loras)
185
+ raise gr.Error(f"Error: {e}")
src/ui/talkinghead.py → modules/helpers/sdxl_helpers.py RENAMED
File without changes
modules/pipelines/flux_pipelines.py ADDED
@@ -0,0 +1,19 @@
1
+ # modules/pipelines/flux_pipelines.py
2
+
3
+ import torch
4
+ from diffusers import AutoPipelineForText2Image, AutoencoderKL
5
+
6
+ def load_flux():
7
+ # Load FLUX models and pipelines
8
+ # ...
9
+ return device, models, flux_vae, controlnet
10
+
11
+ # modules/pipelines/sdxl_pipelines.py
12
+
13
+ import torch
14
+ from diffusers import AutoPipelineForText2Image, AutoencoderKL
15
+
16
+ def load_sdxl():
17
+ # Load SDXL models and pipelines
18
+ # ...
19
+ return device, models, sdxl_vae, controlnet
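As committed, flux_pipelines.py is a stub: `load_flux` returns names (`device`, `models`, `flux_vae`, `controlnet`) it never defines, and the SDXL variant sits in the same file behind a comment. A working `load_flux` would presumably mirror `load_sd` from modules/helpers/flux_helpers.py in this commit; a rough sketch under that assumption:

```python
# Hypothetical load_flux, modeled on load_sd in modules/helpers/flux_helpers.py (same commit).
import torch
from diffusers import (
    AutoPipelineForText2Image,
    AutoencoderKL,
    FluxControlNetModel,
    FluxMultiControlNetModel,
)

def load_flux():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Pipeline registry, matching the structure load_sd builds
    models = [{
        "repo_id": "black-forest-labs/FLUX.1-dev",
        "loader": "flux",
        "compute_type": torch.bfloat16,
    }]
    for model in models:
        model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
            model["repo_id"], torch_dtype=model["compute_type"]
        )
        model["pipeline"].enable_model_cpu_offload()

    # Full-precision FLUX VAE and the union ControlNet used by the Flux helpers
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)
    controlnet = FluxMultiControlNetModel([
        FluxControlNetModel.from_pretrained(
            "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", torch_dtype=torch.bfloat16
        ).to(device)
    ])

    return device, models, flux_vae, controlnet
```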
src/ui/texts.py → modules/pipelines/sdxl_pipelines.py RENAMED
File without changes
old/app.py ADDED
@@ -0,0 +1,816 @@
1
+ # Testing one file gradio app for zero gpu spaces not working as expected.
2
+ # Check here for the issue: https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/106#66e278a396acd45223e0d00b
3
+
4
+ import os
5
+ import gc
6
+ import json
7
+ import random
8
+ from typing import List, Optional
9
+
10
+ import spaces
11
+ import gradio as gr
12
+ from huggingface_hub import ModelCard
13
+ import torch
14
+ from pydantic import BaseModel
15
+ from PIL import Image
16
+ from diffusers import (
17
+ AutoPipelineForText2Image,
18
+ AutoPipelineForImage2Image,
19
+ AutoPipelineForInpainting,
20
+ DiffusionPipeline,
21
+ AutoencoderKL,
22
+ FluxControlNetModel,
23
+ FluxMultiControlNetModel,
24
+ )
25
+ from huggingface_hub import hf_hub_download
26
+ from diffusers.schedulers import *
27
+ from huggingface_hub import hf_hub_download
28
+ from controlnet_aux.processor import Processor
29
+ from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
30
+
31
+
32
+ # Initialize System
33
+ os.system("pip install --upgrade pip")
34
+
35
+
36
+ def load_sd():
37
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ device = "cuda" if torch.cuda.is_available() else "cpu"
39
+
40
+ # Models
41
+ models = [
42
+ {
43
+ "repo_id": "black-forest-labs/FLUX.1-dev",
44
+ "loader": "flux",
45
+ "compute_type": torch.bfloat16,
46
+ }
47
+ ]
48
+
49
+ for model in models:
50
+ try:
51
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
52
+ model['repo_id'],
53
+ vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device),
54
+ torch_dtype = model['compute_type'],
55
+ safety_checker = None,
56
+ variant = "fp16"
57
+ ).to(device)
58
+ except:
59
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
60
+ model['repo_id'],
61
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
62
+ torch_dtype = model['compute_type'],
63
+ safety_checker = None
64
+ ).to(device)
65
+
66
+ model["pipeline"].enable_model_cpu_offload()
67
+
68
+
69
+ # VAE n Refiner
70
+ flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
71
+ sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
72
+ refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
73
+ refiner.enable_model_cpu_offload()
74
+
75
+
76
+ # ControlNet
77
+ controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
78
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
79
+ torch_dtype=torch.bfloat16
80
+ ).to(device)])
81
+
82
+ return device, models, flux_vae, sdxl_vae, refiner, controlnet
83
+
84
+
85
+ device, models, flux_vae, sdxl_vae, refiner, controlnet = load_sd()
86
+
87
+
88
+ # Models
89
+ class ControlNetReq(BaseModel):
90
+ controlnets: List[str] # ["canny", "tile", "depth"]
91
+ control_images: List[Image.Image]
92
+ controlnet_conditioning_scale: List[float]
93
+
94
+ class Config:
95
+ arbitrary_types_allowed=True
96
+
97
+
98
+ class FluxReq(BaseModel):
99
+ model: str = ""
100
+ prompt: str = ""
101
+ fast_generation: Optional[bool] = True
102
+ loras: Optional[list] = []
103
+ resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
104
+ scheduler: Optional[str] = "euler_fl"
105
+ height: int = 1024
106
+ width: int = 1024
107
+ num_images_per_prompt: int = 1
108
+ num_inference_steps: int = 8
109
+ guidance_scale: float = 3.5
110
+ seed: Optional[int] = 0
111
+ refiner: bool = False
112
+ vae: bool = True
113
+ controlnet_config: Optional[ControlNetReq] = None
114
+
115
+ class Config:
116
+ arbitrary_types_allowed=True
117
+
118
+
119
+ class FluxImg2ImgReq(FluxReq):
120
+ image: Image.Image
121
+ strength: float = 1.0
122
+
123
+ class Config:
124
+ arbitrary_types_allowed=True
125
+
126
+
127
+ class FluxInpaintReq(FluxImg2ImgReq):
128
+ mask_image: Image.Image
129
+
130
+ class Config:
131
+ arbitrary_types_allowed=True
132
+
133
+
134
+ # Helper Functions
135
+ def get_control_mode(controlnet_config: ControlNetReq):
136
+ control_mode = []
137
+ layers = ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"]
138
+
139
+ for c in controlnet_config.controlnets:
140
+ if c in layers:
141
+ control_mode.append(layers.index(c))
142
+
143
+ return control_mode
144
+
145
+
146
+ def get_pipe(request: FluxReq | FluxImg2ImgReq | FluxInpaintReq):
147
+ for m in models:
148
+ if m['repo_id'] == request.model:
149
+ pipe_args = {
150
+ "pipeline": m['pipeline'],
151
+ }
152
+
153
+
154
+ # Set ControlNet config
155
+ if request.controlnet_config:
156
+ pipe_args["control_mode"] = get_control_mode(request.controlnet_config)
157
+ pipe_args["controlnet"] = [controlnet]
158
+
159
+
160
+ # Choose Pipeline Mode
161
+ # Check the most specific subclass first: FluxInpaintReq and FluxImg2ImgReq are subclasses of FluxReq,
+ # so checking FluxReq first would always select the text2image pipeline.
+ if isinstance(request, FluxInpaintReq):
+ pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
+ elif isinstance(request, FluxImg2ImgReq):
+ pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
+ else:
+ pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
167
+
168
+
169
+ # Enable or Disable Refiner
170
+ if request.vae:
171
+ pipe_args["pipeline"].vae = flux_vae
172
+ elif not request.vae:
173
+ pipe_args["pipeline"].vae = None
174
+
175
+
176
+ # Set Scheduler
177
+ pipe_args["pipeline"].scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe_args["pipeline"].scheduler.config)
178
+
179
+
180
+ # Set Loras
181
+ if request.loras:
182
+ for i, lora in enumerate(request.loras):
183
+ pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
184
+ adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
185
+ adapter_weights = [lora['weight'] for lora in request.loras]
186
+
187
+ if request.fast_generation:
188
+ hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
189
+ hyper_weight = 0.125
190
+ pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
191
+ adapter_names.append("hyper_lora")
192
+ adapter_weights.append(hyper_weight)
193
+
194
+ pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)
195
+
196
+ return pipe_args
197
+
198
+
199
+ def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
+ # The UI radio passes values with spaces ("resize and fill") while the request models use underscores, so normalize first.
+ resize_mode = resize_mode.replace(" ", "_")
+ resized = []
+ for image in images:
+ if resize_mode == "resize_only":
+ resized.append(image.resize((width, height)))
+ elif resize_mode == "crop_and_resize":
+ resized.append(image.crop((0, 0, width, height)))
+ elif resize_mode == "resize_and_fill":
+ resized.append(image.resize((width, height), Image.Resampling.LANCZOS))
+ else:
+ resized.append(image)
+
+ return resized
209
+
210
+
211
+ def get_controlnet_images(controlnet_config: ControlNetReq, height: int, width: int, resize_mode: str):
212
+ response_images = []
213
+ control_images = resize_images(controlnet_config.control_images, height, width, resize_mode)
214
+ for controlnet, image in zip(controlnet_config.controlnets, control_images):
215
+ if controlnet == "canny":
216
+ processor = Processor('canny')
217
+ elif controlnet == "depth":
218
+ processor = Processor('depth_midas')
219
+ elif controlnet == "pose":
220
+ processor = Processor('openpose_full')
221
+ else:
222
+ raise ValueError(f"Invalid Controlnet: {controlnet}")
223
+
224
+ response_images.append(processor(image, to_pil=True))
225
+
226
+ return response_images
227
+
228
+
229
+ def get_prompt_attention(pipeline, prompt):
230
+ return get_weighted_text_embeddings_flux1(pipeline, prompt)
231
+
232
+
233
+ def cleanup(pipeline, loras = None):
234
+ if loras:
235
+ pipeline.unload_lora_weights()
236
+ gc.collect()
237
+ torch.cuda.empty_cache()
238
+
239
+
240
+ # Gen Function
241
+ def gen_img(request: FluxReq | FluxImg2ImgReq | FluxInpaintReq):
242
+ pipe_args = get_pipe(request)
243
+ pipeline = pipe_args["pipeline"]
244
+ try:
245
+ positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)
246
+
247
+ # Common Args
248
+ args = {
249
+ 'prompt_embeds': positive_prompt_embeds,
250
+ 'pooled_prompt_embeds': positive_prompt_pooled,
251
+ 'height': request.height,
252
+ 'width': request.width,
253
+ 'num_images_per_prompt': request.num_images_per_prompt,
254
+ 'num_inference_steps': request.num_inference_steps,
255
+ 'guidance_scale': request.guidance_scale,
256
+ 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if request.seed not in (None, 0, -1) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
257
+ }
258
+
259
+ if request.controlnet_config:
260
+ args['control_mode'] = get_control_mode(request.controlnet_config)
261
+ args['control_images'] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
262
+ args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
263
+
264
+ if isinstance(request, (FluxImg2ImgReq, FluxInpaintReq)):
265
+ args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
266
+ args['strength'] = request.strength
267
+
268
+ if isinstance(request, FluxInpaintReq):
269
+ args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
270
+
271
+ # Generate
272
+ images = pipeline(**args).images
273
+
274
+ # Refiner
275
+ if request.refiner:
276
+ images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
277
+
278
+ cleanup(pipeline, request.loras)
279
+
280
+ return images
281
+ except Exception as e:
282
+ cleanup(pipeline, request.loras)
283
+ raise gr.Error(f"Error: {e}")
284
+
285
+
286
+
287
+ # CSS
288
+ css = """
289
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
290
+ body {
291
+ font-family: 'Poppins', sans-serif !important;
292
+ }
293
+ .center-content {
294
+ text-align: center;
295
+ max-width: 600px;
296
+ margin: 0 auto;
297
+ padding: 20px;
298
+ }
299
+ .center-content h1 {
300
+ font-weight: 600;
301
+ margin-bottom: 1rem;
302
+ }
303
+ .center-content p {
304
+ margin-bottom: 1.5rem;
305
+ }
306
+ """
307
+
308
+
309
+ flux_models = ["black-forest-labs/FLUX.1-dev"]
310
+ with open("data/images/loras/flux.json", "r") as f:
311
+ loras = json.load(f)
312
+
313
+
314
+ # Event functions
315
+ def update_fast_generation(model, fast_generation):
316
+ if fast_generation:
317
+ return (
318
+ gr.update(
319
+ value=3.5
320
+ ),
321
+ gr.update(
322
+ value=8
323
+ )
324
+ )
+ else:
+ # When fast generation is unchecked, leave guidance scale and step count unchanged.
+ return (
+ gr.update(),
+ gr.update()
+ )
325
+
326
+
327
+ def selected_lora_from_gallery(evt: gr.SelectData):
328
+ return (
329
+ gr.update(
330
+ value=evt.index
331
+ )
332
+ )
333
+
334
+
335
+ def update_selected_lora(custom_lora):
336
+ link = custom_lora.split("/")
337
+
338
+ if len(link) == 2:
339
+ model_card = ModelCard.load(custom_lora)
340
+ trigger_word = model_card.data.get("instance_prompt", "")
341
+ image_url = f"""https://huggingface.co/{custom_lora}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
342
+
343
+ custom_lora_info_css = """
344
+ <style>
345
+ .custom-lora-info {
346
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
347
+ background: linear-gradient(135deg, #4a90e2, #7b61ff);
348
+ color: white;
349
+ padding: 16px;
350
+ border-radius: 8px;
351
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
352
+ margin: 16px 0;
353
+ }
354
+ .custom-lora-header {
355
+ font-size: 18px;
356
+ font-weight: 600;
357
+ margin-bottom: 12px;
358
+ }
359
+ .custom-lora-content {
360
+ display: flex;
361
+ align-items: center;
362
+ background-color: rgba(255, 255, 255, 0.1);
363
+ border-radius: 6px;
364
+ padding: 12px;
365
+ }
366
+ .custom-lora-image {
367
+ width: 80px;
368
+ height: 80px;
369
+ object-fit: cover;
370
+ border-radius: 6px;
371
+ margin-right: 16px;
372
+ }
373
+ .custom-lora-text h3 {
374
+ margin: 0 0 8px 0;
375
+ font-size: 16px;
376
+ font-weight: 600;
377
+ }
378
+ .custom-lora-text small {
379
+ font-size: 14px;
380
+ opacity: 0.9;
381
+ }
382
+ .custom-trigger-word {
383
+ background-color: rgba(255, 255, 255, 0.2);
384
+ padding: 2px 6px;
385
+ border-radius: 4px;
386
+ font-weight: 600;
387
+ }
388
+ </style>
389
+ """
390
+
391
+ custom_lora_info_html = f"""
392
+ <div class="custom-lora-info">
393
+ <div class="custom-lora-header">Custom LoRA: {custom_lora}</div>
394
+ <div class="custom-lora-content">
395
+ <img class="custom-lora-image" src="{image_url}" alt="LoRA preview">
396
+ <div class="custom-lora-text">
397
+ <h3>{link[1].replace("-", " ").replace("_", " ")}</h3>
398
+ <small>{"Using: <span class='custom-trigger-word'>"+trigger_word+"</span> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}</small>
399
+ </div>
400
+ </div>
401
+ </div>
402
+ """
403
+
404
+ custom_lora_info_html = f"{custom_lora_info_css}{custom_lora_info_html}"
405
+
406
+ return (
407
+ gr.update( # selected_lora
408
+ value=custom_lora,
409
+ ),
410
+ gr.update( # custom_lora_info
411
+ value=custom_lora_info_html,
412
+ visible=True
413
+ )
414
+ )
415
+
416
+ else:
417
+ return (
418
+ gr.update( # selected_lora
419
+ value=custom_lora,
420
+ ),
421
+ gr.update( # custom_lora_info
422
+ value="",
423
+ visible=False
424
+ )
425
+ )
426
+
427
+
428
+ def add_to_enabled_loras(model, selected_lora, enabled_loras):
429
+ lora_data = loras
430
+ try:
431
+ selected_lora = int(selected_lora)
432
+
433
+ if 0 <= selected_lora: # is the index of the lora in the gallery
434
+ lora_info = lora_data[selected_lora]
435
+ enabled_loras.append({
436
+ "repo_id": lora_info["repo"],
437
+ "trigger_word": lora_info["trigger_word"]
438
+ })
439
+ except ValueError:
440
+ link = selected_lora.split("/")
441
+ if len(link) == 2:
442
+ model_card = ModelCard.load(selected_lora)
443
+ trigger_word = model_card.data.get("instance_prompt", "")
444
+ enabled_loras.append({
445
+ "repo_id": selected_lora,
446
+ "trigger_word": trigger_word
447
+ })
448
+
449
+ return (
450
+ gr.update( # selected_lora
451
+ value=""
452
+ ),
453
+ gr.update( # custom_lora_info
454
+ value="",
455
+ visible=False
456
+ ),
457
+ gr.update( # enabled_loras
458
+ value=enabled_loras
459
+ )
460
+ )
461
+
462
+
463
+ def update_lora_sliders(enabled_loras):
464
+ sliders = []
465
+ remove_buttons = []
466
+
467
+ for lora in enabled_loras:
468
+ sliders.append(
469
+ gr.update(
470
+ label=lora.get("repo_id", ""),
471
+ info=f"Trigger Word: {lora.get('trigger_word', '')}",
472
+ visible=True,
473
+ interactive=True
474
+ )
475
+ )
476
+ remove_buttons.append(
477
+ gr.update(
478
+ visible=True,
479
+ interactive=True
480
+ )
481
+ )
482
+
483
+ if len(sliders) < 6:
484
+ for i in range(len(sliders), 6):
485
+ sliders.append(
486
+ gr.update(
487
+ visible=False
488
+ )
489
+ )
490
+ remove_buttons.append(
491
+ gr.update(
492
+ visible=False
493
+ )
494
+ )
495
+
496
+ return *sliders, *remove_buttons
497
+
498
+
499
+ def remove_from_enabled_loras(enabled_loras, index):
500
+ enabled_loras.pop(index)
501
+ return (
502
+ gr.update(
503
+ value=enabled_loras
504
+ )
505
+ )
506
+
507
+
508
+ @spaces.GPU
509
+ def generate_image(
510
+ model, prompt, fast_generation, enabled_loras,
511
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5,
512
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image,
513
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength,
514
+ resize_mode,
515
+ scheduler, image_height, image_width, image_num_images_per_prompt,
516
+ image_num_inference_steps, image_guidance_scale, image_seed,
517
+ refiner, vae
518
+ ):
519
+ base_args = {
520
+ "model": model,
521
+ "prompt": prompt,
522
+ "fast_generation": fast_generation,
523
+ "loras": None,
524
+ "resize_mode": resize_mode,
525
+ "scheduler": scheduler,
526
+ "height": int(image_height),
527
+ "width": int(image_width),
528
+ "num_images_per_prompt": float(image_num_images_per_prompt),
529
+ "num_inference_steps": float(image_num_inference_steps),
530
+ "guidance_scale": float(image_guidance_scale),
531
+ "seed": int(image_seed),
532
+ "refiner": refiner,
533
+ "vae": vae,
534
+ "controlnet_config": None,
535
+ }
536
+ base_args = FluxReq(**base_args)
537
+
538
+ if len(enabled_loras) > 0:
539
+ base_args.loras = []
540
+ for enabled_lora, slider in zip(enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5]):
541
+ if enabled_lora['repo_id']:
542
+ base_args.loras.append({
543
+ "repo_id": enabled_lora['repo_id'],
544
+ "weight": slider
545
+ })
546
+
547
+ image = None
548
+ mask_image = None
549
+ strength = None
550
+
551
+ if img2img_image:
552
+ image = img2img_image
553
+ strength = float(img2img_strength)
554
+
555
+ base_args = FluxImg2ImgReq(
556
+ **base_args.__dict__,
557
+ image=image,
558
+ strength=strength
559
+ )
560
+ elif inpaint_image:
561
+ image = inpaint_image['background'] if not all(pixel == (0, 0, 0) for pixel in list(inpaint_image['background'].getdata())) else None
562
+ mask_image = inpaint_image['layers'][0] if image else None
563
+ strength = float(inpaint_strength)
564
+
565
+ if image and mask_image:
566
+ base_args = FluxInpaintReq(
567
+ **base_args.__dict__,
568
+ image=image,
569
+ mask_image=mask_image,
570
+ strength=strength
571
+ )
572
+ elif any([canny_image, pose_image, depth_image]):
573
+ base_args.controlnet_config = ControlNetReq(
574
+ controlnets=[],
575
+ control_images=[],
576
+ controlnet_conditioning_scale=[]
577
+ )
578
+
579
+ if canny_image:
580
+ base_args.controlnet_config.controlnets.append("canny")
581
+ base_args.controlnet_config.control_images.append(canny_image)
582
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(canny_strength))
583
+ if pose_image:
584
+ base_args.controlnet_config.controlnets.append("pose")
585
+ base_args.controlnet_config.control_images.append(pose_image)
586
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(pose_strength))
587
+ if depth_image:
588
+ base_args.controlnet_config.controlnets.append("depth")
589
+ base_args.controlnet_config.control_images.append(depth_image)
590
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
591
+ else:
592
+ base_args = FluxReq(**base_args.__dict__)
593
+
594
+ return gr.update(
595
+ value=gen_img(base_args),
596
+ interactive=True
597
+ )
598
+
599
+
600
+ # Main Gradio app
601
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
602
+ # Header
603
+ with gr.Column(elem_classes="center-content"):
604
+ gr.Markdown("""
605
+ # πŸš€ AAI: All AI
606
+ Unleash your creativity with our multi-modal AI platform.
607
+ [![Sync code to HF Space](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml/badge.svg)](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml)
608
+ """)
609
+
610
+ # Tabs
611
+ with gr.Tabs():
612
+ with gr.Tab(label="πŸ–ΌοΈ Image"):
613
+ with gr.Tabs():
614
+ with gr.Tab("Flux"):
615
+ """
616
+ Create the image tab for Generative Image Generation Models
617
+
618
+ Args:
619
+ models: list
620
+ A list containing the models repository paths
621
+ gap_iol, gap_la, gap_le, gap_eio, gap_io: Optional[List[dict]]
622
+ A list of dictionaries containing the title and component for the custom gradio component
623
+ Example:
624
+ def gr_comp():
625
+ gr.Label("Hello World")
626
+
627
+ [
628
+ {
629
+ 'title': "Title",
630
+ 'component': gr_comp()
631
+ }
632
+ ]
633
+ loras: list
634
+ A list of dictionaries containing the image and title for the Loras Gallery
635
+ Generally a loaded json file from the data folder
636
+
637
+ """
638
+ def process_gaps(gaps: List[dict]):
639
+ for gap in gaps:
640
+ with gr.Accordion(gap['title']):
641
+ gap['component']
642
+
643
+
644
+ with gr.Row():
645
+ with gr.Column():
646
+ with gr.Group() as image_options:
647
+ model = gr.Dropdown(label="Models", choices=flux_models, value=flux_models[0], interactive=True)
648
+ prompt = gr.Textbox(lines=5, label="Prompt")
649
+ fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) πŸ§ͺ")
650
+
651
+
652
+ with gr.Accordion("Loras", open=True): # Lora Gallery
653
+ lora_gallery = gr.Gallery(
654
+ label="Gallery",
655
+ value=[(lora['image'], lora['title']) for lora in loras],
656
+ allow_preview=False,
657
+ columns=3,
658
+ rows=3,
659
+ type="pil"
660
+ )
661
+
662
+ with gr.Group():
663
+ with gr.Column():
664
+ with gr.Row():
665
+ custom_lora = gr.Textbox(label="Custom Lora", info="Enter a Huggingface repo path")
666
+ selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")
667
+
668
+ custom_lora_info = gr.HTML(visible=False)
669
+ add_lora = gr.Button(value="Add LoRA")
670
+
671
+ enabled_loras = gr.State(value=[])
672
+ with gr.Group():
673
+ with gr.Row():
674
+ for i in range(6): # only support max 6 loras due to inference time
675
+ with gr.Column():
676
+ with gr.Column(scale=2):
677
+ globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
678
+ with gr.Column():
679
+ globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)
680
+
681
+
682
+ with gr.Accordion("Embeddings", open=False): # Embeddings
683
+ gr.Label("To be implemented")
684
+
685
+
686
+ with gr.Accordion("Image Options"): # Image Options
687
+ with gr.Tabs():
688
+ image_options = {
689
+ "img2img": "Upload Image",
690
+ "inpaint": "Upload Image",
691
+ "canny": "Upload Image",
692
+ "pose": "Upload Image",
693
+ "depth": "Upload Image",
694
+ }
695
+
696
+ for image_option, label in image_options.items():
697
+ with gr.Tab(image_option):
698
+ if not image_option in ['inpaint', 'scribble']:
699
+ globals()[f"{image_option}_image"] = gr.Image(label=label, type="pil")
700
+ elif image_option in ['inpaint', 'scribble']:
701
+ globals()[f"{image_option}_image"] = gr.ImageEditor(
702
+ label=label,
703
+ image_mode='RGB',
704
+ layers=False,
705
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
706
+ interactive=True,
707
+ type="pil",
708
+ )
709
+
710
+ # Image Strength (corresponds to ControlNet conditioning scale, and to denoising strength for img2img and inpaint)
711
+ globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)
712
+
713
+ resize_mode = gr.Radio(
714
+ label="Resize Mode",
715
+ choices=["crop and resize", "resize only", "resize and fill"],
716
+ value="resize and fill",
717
+ interactive=True
718
+ )
719
+
720
+
721
+ with gr.Column():
722
+ with gr.Group():
723
+ output_images = gr.Gallery(
724
+ label="Output Images",
725
+ value=[],
726
+ allow_preview=True,
727
+ type="pil",
728
+ interactive=False,
729
+ )
730
+ generate_images = gr.Button(value="Generate Images", variant="primary")
731
+
732
+ with gr.Accordion("Advanced Settings", open=True):
733
+ with gr.Row():
734
+ scheduler = gr.Dropdown(
735
+ label="Scheduler",
736
+ choices = [
737
+ "fm_euler"
738
+ ],
739
+ value="fm_euler",
740
+ interactive=True
741
+ )
742
+
743
+ with gr.Row():
744
+ for column in range(2):
745
+ with gr.Column():
746
+ options = [
747
+ ("Height", "image_height", 64, 1024, 64, 1024, True),
748
+ ("Width", "image_width", 64, 1024, 64, 1024, True),
749
+ ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
750
+ ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
751
+ ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
752
+ ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 3.5, True),
753
+ ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
754
+ ]
755
+ for label, var_name, min_val, max_val, step, value, visible in options[column::2]:
756
+ globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)
757
+
758
+ with gr.Row():
759
+ refiner = gr.Checkbox(
760
+ label="Refiner πŸ§ͺ",
761
+ value=False,
762
+ )
763
+ vae = gr.Checkbox(
764
+ label="VAE",
765
+ value=True,
766
+ )
767
+
768
+
769
+ # Events
770
+ # Base Options
771
+ fast_generation.change(update_fast_generation, [model, fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
772
+
773
+
774
+ # Lora Gallery
775
+ lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
776
+ custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
777
+ add_lora.click(add_to_enabled_loras, [model, selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
778
+ enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
779
+
780
+ for i in range(6):
781
+ globals()[f"lora_remove_{i}"].click(
782
+ lambda enabled_loras, index=i: remove_from_enabled_loras(enabled_loras, index),
783
+ [enabled_loras],
784
+ [enabled_loras]
785
+ )
786
+
787
+
788
+ # Generate Image
789
+ generate_images.click(
790
+ generate_image, # type: ignore
791
+ [
792
+ model, prompt, fast_generation, enabled_loras,
793
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
794
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image, # type: ignore
795
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, # type: ignore
796
+ resize_mode,
797
+ scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
798
+ image_num_inference_steps, image_guidance_scale, image_seed, # type: ignore
799
+ refiner, vae
800
+ ],
801
+ [output_images]
802
+ )
803
+ with gr.Tab("SDXL"):
804
+ gr.Label("To be implemented")
805
+ with gr.Tab(label="🎡 Audio"):
806
+ gr.Label("Coming soon!")
807
+ with gr.Tab(label="🎬 Video"):
808
+ gr.Label("Coming soon!")
809
+ with gr.Tab(label="πŸ“„ Text"):
810
+ gr.Label("Coming soon!")
811
+
812
+
813
+ demo.launch(
814
+ share=False,
815
+ debug=True,
816
+ )
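
For reference, here is a minimal sketch (not part of this commit) of how the request models and `gen_img` defined above could be exercised directly, without going through the Gradio UI. It assumes app.py is importable as `app` with the module-level `demo.launch()` call guarded behind `if __name__ == "__main__":`; the prompt, seed, and output path are placeholder values and the call path is illustrative rather than tested.

from app import FluxReq, gen_img

request = FluxReq(
    model="black-forest-labs/FLUX.1-dev",
    prompt="a watercolor lighthouse at dusk",   # placeholder prompt
    fast_generation=True,                       # uses the Hyper-SD 8-step LoRA path
    height=1024,
    width=1024,
    num_images_per_prompt=1,
    num_inference_steps=8,
    guidance_scale=3.5,
    seed=42,
)

images = gen_img(request)        # returns a list of PIL.Image.Image
images[0].save("flux_output.png")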
app2.py β†’ old/app2.py RENAMED
File without changes
app3.py β†’ old/app3.py RENAMED
File without changes
{src β†’ old/src}/tasks/images/init_sys.py RENAMED
File without changes
{src β†’ old/src}/tasks/images/sd.py RENAMED
File without changes
{src β†’ old/src}/ui/__init__.py RENAMED
File without changes
src/ui/videos.py β†’ old/src/ui/audios.py RENAMED
File without changes
{src β†’ old/src}/ui/images.py RENAMED
File without changes
{src β†’ old/src}/ui/tabs/__init__.py RENAMED
File without changes
{src β†’ old/src}/ui/tabs/images/flux.py RENAMED
File without changes
old/src/ui/talkinghead.py ADDED
File without changes
old/src/ui/texts.py ADDED
File without changes
old/src/ui/videos.py ADDED
File without changes
tabs/audio_tab.py ADDED
@@ -0,0 +1,5 @@
+ import gradio as gr
+
+
+ def audio_tab():
+ gr.Label("Coming soon...")
tabs/image_tab.py ADDED
@@ -0,0 +1,183 @@
1
+ # tabs/image_tab.py
2
+
3
+ import random
+
+ import gradio as gr
4
+ from modules.events.flux_events import *
5
+ from modules.events.sdxl_events import *
6
+ from modules.helpers.common_helpers import *
7
+ from modules.helpers.flux_helpers import *
8
+ from modules.helpers.sdxl_helpers import *
9
+ from config import flux_models, sdxl_models, loras
10
+
11
+
12
+ def image_tab():
13
+ with gr.Tab(label="πŸ–ΌοΈ Image"):
14
+ with gr.Tabs():
15
+ flux_tab()
16
+ sdxl_tab()
17
+
18
+
19
+ def flux_tab():
20
+ with gr.Tab("Flux"):
21
+ with gr.Row():
22
+ with gr.Column():
23
+ with gr.Group() as image_options:
24
+ model = gr.Dropdown(label="Models", choices=flux_models, value=flux_models[0], interactive=True)
25
+ prompt = gr.Textbox(lines=5, label="Prompt")
26
+ fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) πŸ§ͺ")
27
+
28
+
29
+ with gr.Accordion("Loras", open=True): # Lora Gallery
30
+ lora_gallery = gr.Gallery(
31
+ label="Gallery",
32
+ value=[(lora['image'], lora['title']) for lora in loras],
33
+ allow_preview=False,
34
+ columns=3,
35
+ rows=3,
36
+ type="pil"
37
+ )
38
+
39
+ with gr.Group():
40
+ with gr.Column():
41
+ with gr.Row():
42
+ custom_lora = gr.Textbox(label="Custom Lora", info="Enter a Huggingface repo path")
43
+ selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")
44
+
45
+ custom_lora_info = gr.HTML(visible=False)
46
+ add_lora = gr.Button(value="Add LoRA")
47
+
48
+ enabled_loras = gr.State(value=[])
49
+ with gr.Group():
50
+ with gr.Row():
51
+ for i in range(6): # only support max 6 loras due to inference time
52
+ with gr.Column():
53
+ with gr.Column(scale=2):
54
+ globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
55
+ with gr.Column():
56
+ globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)
57
+
58
+
59
+ with gr.Accordion("Embeddings", open=False): # Embeddings
60
+ gr.Label("To be implemented")
61
+
62
+
63
+ with gr.Accordion("Image Options"): # Image Options
64
+ with gr.Tabs():
65
+ image_options = {
66
+ "img2img": "Upload Image",
67
+ "inpaint": "Upload Image",
68
+ "canny": "Upload Image",
69
+ "pose": "Upload Image",
70
+ "depth": "Upload Image",
71
+ }
72
+
73
+ for image_option, label in image_options.items():
74
+ with gr.Tab(image_option):
75
+ if not image_option in ['inpaint', 'scribble']:
76
+ globals()[f"{image_option}_image"] = gr.Image(label=label, type="pil")
77
+ elif image_option in ['inpaint', 'scribble']:
78
+ globals()[f"{image_option}_image"] = gr.ImageEditor(
79
+ label=label,
80
+ image_mode='RGB',
81
+ layers=False,
82
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
83
+ interactive=True,
84
+ type="pil",
85
+ )
86
+
87
+ # Image Strength (corresponds to ControlNet conditioning scale, and to denoising strength for img2img and inpaint)
88
+ globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)
89
+
90
+ resize_mode = gr.Radio(
91
+ label="Resize Mode",
92
+ choices=["crop and resize", "resize only", "resize and fill"],
93
+ value="resize and fill",
94
+ interactive=True
95
+ )
96
+
97
+
98
+ with gr.Column():
99
+ with gr.Group():
100
+ output_images = gr.Gallery(
101
+ label="Output Images",
102
+ value=[],
103
+ allow_preview=True,
104
+ type="pil",
105
+ interactive=False,
106
+ )
107
+ generate_images = gr.Button(value="Generate Images", variant="primary")
108
+
109
+ with gr.Accordion("Advanced Settings", open=True):
110
+ with gr.Row():
111
+ scheduler = gr.Dropdown(
112
+ label="Scheduler",
113
+ choices = [
114
+ "fm_euler"
115
+ ],
116
+ value="fm_euler",
117
+ interactive=True
118
+ )
119
+
120
+ with gr.Row():
121
+ for column in range(2):
122
+ with gr.Column():
123
+ options = [
124
+ ("Height", "image_height", 64, 1024, 64, 1024, True),
125
+ ("Width", "image_width", 64, 1024, 64, 1024, True),
126
+ ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
127
+ ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
128
+ ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
129
+ ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 3.5, True),
130
+ ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
131
+ ]
132
+ for label, var_name, min_val, max_val, step, value, visible in options[column::2]:
133
+ globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)
134
+
135
+ with gr.Row():
136
+ refiner = gr.Checkbox(
137
+ label="Refiner πŸ§ͺ",
138
+ value=False,
139
+ )
140
+ vae = gr.Checkbox(
141
+ label="VAE",
142
+ value=True,
143
+ )
144
+
145
+ # Events
146
+ # Base Options
147
+ fast_generation.change(update_fast_generation, [model, fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
148
+
149
+
150
+ # Lora Gallery
151
+ lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
152
+ custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
153
+ add_lora.click(add_to_enabled_loras, [model, selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
154
+ enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
155
+
156
+ for i in range(6):
157
+ globals()[f"lora_remove_{i}"].click(
158
+ lambda enabled_loras, index=i: remove_from_enabled_loras(enabled_loras, index),
159
+ [enabled_loras],
160
+ [enabled_loras]
161
+ )
162
+
163
+
164
+ # Generate Image
165
+ generate_images.click(
166
+ generate_image, # type: ignore
167
+ [
168
+ model, prompt, fast_generation, enabled_loras,
169
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
170
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image, # type: ignore
171
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, # type: ignore
172
+ resize_mode,
173
+ scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
174
+ image_num_inference_steps, image_guidance_scale, image_seed, # type: ignore
175
+ refiner, vae
176
+ ],
177
+ [output_images]
178
+ )
179
+
180
+
181
+ def sdxl_tab():
182
+ with gr.Tab("SDXL"):
183
+ gr.Label("To be implemented")
tabs/text_tab.py ADDED
@@ -0,0 +1,5 @@
+ import gradio as gr
+
+
+ def text_tab():
+ gr.Label("Coming soon...")
tabs/video_tab.py ADDED
@@ -0,0 +1,5 @@
+ import gradio as gr
+
+
+ def video_tab():
+ gr.Label("Coming soon...")
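
Looking ahead, a minimal sketch (assumed layout, not included in this commit) of how a top-level app.py could compose the new tab modules. `image_tab()` creates its own gr.Tab internally (with Flux and SDXL sub-tabs), while `audio_tab()`, `video_tab()`, and `text_tab()` only emit a placeholder label, so they are wrapped in tabs here; the tab titles mirror the ones used in the monolithic app above.

import gradio as gr

from tabs.image_tab import image_tab
from tabs.audio_tab import audio_tab
from tabs.video_tab import video_tab
from tabs.text_tab import text_tab

with gr.Blocks() as demo:
    with gr.Tabs():
        image_tab()  # builds its own "🖼️ Image" tab
        with gr.Tab(label="🎵 Audio"):
            audio_tab()
        with gr.Tab(label="🎬 Video"):
            video_tab()
        with gr.Tab(label="📄 Text"):
            text_tab()

demo.launch()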