multimodalart HF staff commited on
Commit
4f5b1e9
·
verified ·
1 Parent(s): 16490f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -91
app.py CHANGED
@@ -25,17 +25,18 @@ base_model = "black-forest-labs/FLUX.1-dev"
25
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
26
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
27
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
28
- pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model,
29
- vae=good_vae,
30
- transformer=pipe.transformer,
31
- text_encoder=pipe.text_encoder,
32
- tokenizer=pipe.tokenizer,
33
- text_encoder_2=pipe.text_encoder_2,
34
- tokenizer_2=pipe.tokenizer_2,
35
- torch_dtype=dtype
36
- )
37
-
38
- MAX_SEED = 2**32-1
 
39
 
40
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
41
 
@@ -46,7 +47,7 @@ class calculateDuration:
46
  def __enter__(self):
47
  self.start_time = time.time()
48
  return self
49
-
50
  def __exit__(self, exc_type, exc_value, traceback):
51
  self.end_time = time.time()
52
  self.elapsed_time = self.end_time - self.start_time
@@ -66,7 +67,7 @@ def update_selection(evt: gr.SelectData, selected_indices, width, height):
66
  selected_indices.append(selected_index)
67
  else:
68
  gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
69
- return(
70
  gr.update(),
71
  gr.update(),
72
  gr.update(),
@@ -80,19 +81,19 @@ def update_selection(evt: gr.SelectData, selected_indices, width, height):
80
  )
81
 
82
  # Initialize outputs
83
- selected_info_1 = ""
84
- selected_info_2 = ""
85
  lora_scale_1 = 0.95
86
  lora_scale_2 = 0.95
87
  lora_image_1 = None
88
  lora_image_2 = None
89
  if len(selected_indices) >= 1:
90
  lora1 = loras[selected_indices[0]]
91
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
92
  lora_image_1 = lora1['image']
93
  if len(selected_indices) >= 2:
94
  lora2 = loras[selected_indices[1]]
95
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
96
  lora_image_2 = lora2['image']
97
 
98
  # Update prompt placeholder based on last selected LoRA
@@ -128,11 +129,11 @@ def remove_lora_1(selected_indices):
128
  lora_image_2 = None
129
  if len(selected_indices) >= 1:
130
  lora1 = loras[selected_indices[0]]
131
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
132
  lora_image_1 = lora1['image']
133
  if len(selected_indices) >= 2:
134
  lora2 = loras[selected_indices[1]]
135
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
136
  lora_image_2 = lora2['image']
137
  return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
138
 
@@ -149,11 +150,11 @@ def remove_lora_2(selected_indices):
149
  lora_image_2 = None
150
  if len(selected_indices) >= 1:
151
  lora1 = loras[selected_indices[0]]
152
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
153
  lora_image_1 = lora1['image']
154
  if len(selected_indices) >= 2:
155
  lora2 = loras[selected_indices[1]]
156
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
157
  lora_image_2 = lora2['image']
158
  return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
159
 
@@ -163,8 +164,8 @@ def randomize_loras(selected_indices):
163
  selected_indices = random.sample(range(len(loras)), 2)
164
  lora1 = loras[selected_indices[0]]
165
  lora2 = loras[selected_indices[1]]
166
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
167
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
168
  lora_scale_1 = 0.95
169
  lora_scale_2 = 0.95
170
  lora_image_1 = lora1['image']
@@ -173,7 +174,7 @@ def randomize_loras(selected_indices):
173
 
174
  @spaces.GPU(duration=70)
175
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
176
- print("Entrou aqui!")
177
  pipe.to("cuda")
178
  generator = torch.Generator(device="cuda").manual_seed(seed)
179
  with calculateDuration("Generating image"):
@@ -208,8 +209,8 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
208
  joint_attention_kwargs={"scale": 1.0},
209
  output_type="pil",
210
  ).images[0]
211
- return final_image
212
-
213
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, progress=gr.Progress(track_tqdm=True)):
214
  if not selected_indices:
215
  raise gr.Error("You must select at least one LoRA before proceeding.")
@@ -235,27 +236,29 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
235
 
236
  # Load LoRA weights with respective scales
237
  lora_names = []
 
238
  with calculateDuration("Loading LoRA weights"):
239
  for idx, lora in enumerate(selected_loras):
240
  lora_name = f"lora_{idx}"
241
  lora_names.append(lora_name)
 
242
  lora_path = lora['repo']
243
- scale = lora_scale_1 if idx == 0 else lora_scale_2
244
  if image_input is not None:
245
- if "weights" in lora:
246
- pipe_i2i.load_lora_weights(lora_path, weight_name=lora["weights"], low_cpu_mem_usage=True, adapter_name=lora_name)
247
  else:
248
  pipe_i2i.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
249
  else:
250
- if "weights" in lora:
251
- pipe.load_lora_weights(lora_path, weight_name=lora["weights"], low_cpu_mem_usage=True, adapter_name=lora_name)
252
  else:
253
  pipe.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
254
- print(lora_names)
255
  if image_input is not None:
256
- pipe_i2i.set_adapters(lora_names, adapter_weights=[lora_scale_1, lora_scale_2])
257
  else:
258
- pipe.set_adapters(lora_names, adapter_weights=[lora_scale_1, lora_scale_2])
259
  # Set random seed for reproducibility
260
  with calculateDuration("Randomizing seed"):
261
  if randomize_seed:
@@ -271,7 +274,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
271
  final_image = None
272
  step_counter = 0
273
  for image in image_generator:
274
- step_counter+=1
275
  final_image = image
276
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
277
  yield image, seed, gr.update(value=progress_bar, visible=True)
@@ -282,7 +285,7 @@ def get_huggingface_safetensors(link):
282
  if len(split_link) == 2:
283
  model_card = ModelCard.load(link)
284
  base_model = model_card.data.get("base_model")
285
- print(base_model)
286
  if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
287
  raise Exception("Not a FLUX LoRA!")
288
  image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
@@ -304,13 +307,26 @@ def get_huggingface_safetensors(link):
304
  if not safetensors_name:
305
  raise Exception("No *.safetensors file found in the repository")
306
  return split_link[1], link, safetensors_name, trigger_word, image_url
 
 
307
 
308
  def check_custom_model(link):
309
- if link.startswith("https://"):
310
- if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
 
 
 
 
 
 
 
 
311
  link_split = link.split("huggingface.co/")
312
  return get_huggingface_safetensors(link_split[1])
313
- else:
 
 
 
314
  return get_huggingface_safetensors(link)
315
 
316
  def add_custom_lora(custom_lora, selected_indices):
@@ -319,18 +335,6 @@ def add_custom_lora(custom_lora, selected_indices):
319
  try:
320
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
321
  print(f"Loaded custom LoRA: {repo}")
322
- card = f'''
323
- <div class="custom_lora_card">
324
- <span>Loaded custom LoRA:</span>
325
- <div class="card_internal">
326
- <img src="{image}" />
327
- <div>
328
- <h3>{title}</h3>
329
- <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
330
- </div>
331
- </div>
332
- </div>
333
- '''
334
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
335
  if existing_item_index is None:
336
  new_item = {
@@ -340,7 +344,7 @@ def add_custom_lora(custom_lora, selected_indices):
340
  "weights": path,
341
  "trigger_word": trigger_word
342
  }
343
- print(new_item)
344
  existing_item_index = len(loras)
345
  loras.append(new_item)
346
 
@@ -349,32 +353,42 @@ def add_custom_lora(custom_lora, selected_indices):
349
  # Update selected_indices if there's room
350
  if len(selected_indices) < 2:
351
  selected_indices.append(existing_item_index)
352
- selected_info_1 = ""
353
- selected_info_2 = ""
354
- lora_scale_1 = 0.95
355
- lora_scale_2 = 0.95
356
- lora_image_1 = None
357
- lora_image_2 = None
358
- if len(selected_indices) >= 1:
359
- lora1 = loras[selected_indices[0]]
360
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
361
- lora_image_1 = lora1['image']
362
- if len(selected_indices) >= 2:
363
- lora2 = loras[selected_indices[1]]
364
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
365
- lora_image_2 = lora2['image']
366
- return (gr.update(visible=True, value=card), gr.update(visible=True), gr.update(value=gallery_items),
367
- selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2)
368
  else:
369
- return (gr.update(visible=True, value=card), gr.update(visible=True), gr.update(value=gallery_items),
370
- gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange(), gr.NoChange(), gr.NoChange())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  except Exception as e:
372
  print(e)
373
- return gr.update(visible=True, value=str(e)), gr.update(visible=True), gr.NoChange(), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange(), gr.NoChange(), gr.NoChange()
 
374
  else:
375
- return gr.update(visible=False), gr.update(visible=False), gr.NoChange(), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange(), gr.NoChange(), gr.NoChange()
376
 
377
- def remove_custom_lora(custom_lora_info, custom_lora_button, selected_indices):
378
  global loras
379
  if loras:
380
  custom_lora_repo = loras[-1]['repo']
@@ -387,21 +401,30 @@ def remove_custom_lora(custom_lora_info, custom_lora_button, selected_indices):
387
  # Update gallery
388
  gallery_items = [(item["image"], item["title"]) for item in loras]
389
  # Update selected_info and images
390
- selected_info_1 = ""
391
- selected_info_2 = ""
392
  lora_scale_1 = 0.95
393
  lora_scale_2 = 0.95
394
  lora_image_1 = None
395
  lora_image_2 = None
396
  if len(selected_indices) >= 1:
397
  lora1 = loras[selected_indices[0]]
398
- selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
399
  lora_image_1 = lora1['image']
400
  if len(selected_indices) >= 2:
401
  lora2 = loras[selected_indices[1]]
402
- selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
403
  lora_image_2 = lora2['image']
404
- return gr.update(visible=False), gr.update(visible=False), gr.update(value=gallery_items), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
 
 
 
 
 
 
 
 
 
405
 
406
  run_lora.zerogpu = True
407
 
@@ -410,6 +433,7 @@ css = '''
410
  #title{text-align: center}
411
  #title h1{font-size: 3em; display:inline-flex; align-items:center}
412
  #title img{width: 100px; margin-right: 0.5em}
 
413
  #gallery .grid-wrap{height: 5vh}
414
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
415
  .custom_lora_card{margin-bottom: 1em}
@@ -460,6 +484,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
460
  remove_button_2 = gr.Button("Remove", size="sm")
461
  with gr.Row():
462
  with gr.Column():
 
 
 
 
 
463
  gallery = gr.Gallery(
464
  [(item["image"], item["title"]) for item in loras],
465
  label="LoRA Gallery",
@@ -467,11 +496,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
467
  columns=5,
468
  elem_id="gallery"
469
  )
470
- with gr.Group():
471
- custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="multimodalart/vintage-ads-flux")
472
- gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
473
- custom_lora_info = gr.HTML(visible=False)
474
- custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
475
  with gr.Column():
476
  progress_bar = gr.Markdown(elem_id="progress", visible=False)
477
  result = gr.Image(label="Generated Image")
@@ -484,15 +508,15 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
484
  with gr.Row():
485
  cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
486
  steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
487
-
488
  with gr.Row():
489
  width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
490
  height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
491
-
492
  with gr.Row():
493
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
494
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
495
-
496
  gallery.select(
497
  update_selection,
498
  inputs=[selected_indices, width, height],
@@ -513,15 +537,15 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 3600)) as app:
513
  inputs=[selected_indices],
514
  outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
515
  )
516
- custom_lora.change(
517
  add_custom_lora,
518
  inputs=[custom_lora, selected_indices],
519
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
520
  )
521
- custom_lora_button.click(
522
  remove_custom_lora,
523
- inputs=[custom_lora_info, custom_lora_button, selected_indices],
524
- outputs=[custom_lora_info, custom_lora_button, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
525
  )
526
  gr.on(
527
  triggers=[generate_button.click, prompt.submit],
 
25
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
26
  good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
27
  pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
28
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
29
+ base_model,
30
+ vae=good_vae,
31
+ transformer=pipe.transformer,
32
+ text_encoder=pipe.text_encoder,
33
+ tokenizer=pipe.tokenizer,
34
+ text_encoder_2=pipe.text_encoder_2,
35
+ tokenizer_2=pipe.tokenizer_2,
36
+ torch_dtype=dtype
37
+ )
38
+
39
+ MAX_SEED = 2**32 - 1
40
 
41
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
42
 
 
47
  def __enter__(self):
48
  self.start_time = time.time()
49
  return self
50
+
51
  def __exit__(self, exc_type, exc_value, traceback):
52
  self.end_time = time.time()
53
  self.elapsed_time = self.end_time - self.start_time
 
67
  selected_indices.append(selected_index)
68
  else:
69
  gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
70
+ return (
71
  gr.update(),
72
  gr.update(),
73
  gr.update(),
 
81
  )
82
 
83
  # Initialize outputs
84
+ selected_info_1 = "Select a LoRA 1"
85
+ selected_info_2 = "Select a LoRA 2"
86
  lora_scale_1 = 0.95
87
  lora_scale_2 = 0.95
88
  lora_image_1 = None
89
  lora_image_2 = None
90
  if len(selected_indices) >= 1:
91
  lora1 = loras[selected_indices[0]]
92
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
93
  lora_image_1 = lora1['image']
94
  if len(selected_indices) >= 2:
95
  lora2 = loras[selected_indices[1]]
96
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
97
  lora_image_2 = lora2['image']
98
 
99
  # Update prompt placeholder based on last selected LoRA
 
129
  lora_image_2 = None
130
  if len(selected_indices) >= 1:
131
  lora1 = loras[selected_indices[0]]
132
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
133
  lora_image_1 = lora1['image']
134
  if len(selected_indices) >= 2:
135
  lora2 = loras[selected_indices[1]]
136
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
137
  lora_image_2 = lora2['image']
138
  return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
139
 
 
150
  lora_image_2 = None
151
  if len(selected_indices) >= 1:
152
  lora1 = loras[selected_indices[0]]
153
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
154
  lora_image_1 = lora1['image']
155
  if len(selected_indices) >= 2:
156
  lora2 = loras[selected_indices[1]]
157
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
158
  lora_image_2 = lora2['image']
159
  return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
160
 
 
164
  selected_indices = random.sample(range(len(loras)), 2)
165
  lora1 = loras[selected_indices[0]]
166
  lora2 = loras[selected_indices[1]]
167
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
168
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
169
  lora_scale_1 = 0.95
170
  lora_scale_2 = 0.95
171
  lora_image_1 = lora1['image']
 
174
 
175
  @spaces.GPU(duration=70)
176
  def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
177
+ print("Generating image...")
178
  pipe.to("cuda")
179
  generator = torch.Generator(device="cuda").manual_seed(seed)
180
  with calculateDuration("Generating image"):
 
209
  joint_attention_kwargs={"scale": 1.0},
210
  output_type="pil",
211
  ).images[0]
212
+ return final_image
213
+
214
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, progress=gr.Progress(track_tqdm=True)):
215
  if not selected_indices:
216
  raise gr.Error("You must select at least one LoRA before proceeding.")
 
236
 
237
  # Load LoRA weights with respective scales
238
  lora_names = []
239
+ lora_weights = []
240
  with calculateDuration("Loading LoRA weights"):
241
  for idx, lora in enumerate(selected_loras):
242
  lora_name = f"lora_{idx}"
243
  lora_names.append(lora_name)
244
+ lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
245
  lora_path = lora['repo']
246
+ weight_name = lora.get("weights")
247
  if image_input is not None:
248
+ if weight_name:
249
+ pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name)
250
  else:
251
  pipe_i2i.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
252
  else:
253
+ if weight_name:
254
+ pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name)
255
  else:
256
  pipe.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
257
+ print("Loaded LoRAs:", lora_names)
258
  if image_input is not None:
259
+ pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
260
  else:
261
+ pipe.set_adapters(lora_names, adapter_weights=lora_weights)
262
  # Set random seed for reproducibility
263
  with calculateDuration("Randomizing seed"):
264
  if randomize_seed:
 
274
  final_image = None
275
  step_counter = 0
276
  for image in image_generator:
277
+ step_counter += 1
278
  final_image = image
279
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
280
  yield image, seed, gr.update(value=progress_bar, visible=True)
 
285
  if len(split_link) == 2:
286
  model_card = ModelCard.load(link)
287
  base_model = model_card.data.get("base_model")
288
+ print(f"Base model: {base_model}")
289
  if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
290
  raise Exception("Not a FLUX LoRA!")
291
  image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
 
307
  if not safetensors_name:
308
  raise Exception("No *.safetensors file found in the repository")
309
  return split_link[1], link, safetensors_name, trigger_word, image_url
310
+ else:
311
+ raise Exception("Invalid Hugging Face repository link")
312
 
313
  def check_custom_model(link):
314
+ if link.endswith(".safetensors"):
315
+ # Treat as direct link to the LoRA weights
316
+ title = os.path.basename(link)
317
+ repo = link
318
+ path = None # No specific weight name
319
+ trigger_word = ""
320
+ image_url = None
321
+ return title, repo, path, trigger_word, image_url
322
+ elif link.startswith("https://"):
323
+ if "huggingface.co" in link:
324
  link_split = link.split("huggingface.co/")
325
  return get_huggingface_safetensors(link_split[1])
326
+ else:
327
+ raise Exception("Unsupported URL")
328
+ else:
329
+ # Assume it's a Hugging Face model path
330
  return get_huggingface_safetensors(link)
331
 
332
  def add_custom_lora(custom_lora, selected_indices):
 
335
  try:
336
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
337
  print(f"Loaded custom LoRA: {repo}")
 
 
 
 
 
 
 
 
 
 
 
 
338
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
339
  if existing_item_index is None:
340
  new_item = {
 
344
  "weights": path,
345
  "trigger_word": trigger_word
346
  }
347
+ print(f"New LoRA: {new_item}")
348
  existing_item_index = len(loras)
349
  loras.append(new_item)
350
 
 
353
  # Update selected_indices if there's room
354
  if len(selected_indices) < 2:
355
  selected_indices.append(existing_item_index)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  else:
357
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
358
+
359
+ # Update selected_info and images
360
+ selected_info_1 = "Select a LoRA 1"
361
+ selected_info_2 = "Select a LoRA 2"
362
+ lora_scale_1 = 0.95
363
+ lora_scale_2 = 0.95
364
+ lora_image_1 = None
365
+ lora_image_2 = None
366
+ if len(selected_indices) >= 1:
367
+ lora1 = loras[selected_indices[0]]
368
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
369
+ lora_image_1 = lora1['image']
370
+ if len(selected_indices) >= 2:
371
+ lora2 = loras[selected_indices[1]]
372
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
373
+ lora_image_2 = lora2['image']
374
+ return (
375
+ gr.update(value=gallery_items),
376
+ selected_info_1,
377
+ selected_info_2,
378
+ selected_indices,
379
+ lora_scale_1,
380
+ lora_scale_2,
381
+ lora_image_1,
382
+ lora_image_2
383
+ )
384
  except Exception as e:
385
  print(e)
386
+ gr.Error(str(e))
387
+ return gr.NoChange(), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange(), gr.NoChange(), gr.NoChange()
388
  else:
389
+ return gr.NoChange(), gr.NoChange(), gr.NoChange(), selected_indices, gr.NoChange(), gr.NoChange(), gr.NoChange(), gr.NoChange()
390
 
391
+ def remove_custom_lora(selected_indices):
392
  global loras
393
  if loras:
394
  custom_lora_repo = loras[-1]['repo']
 
401
  # Update gallery
402
  gallery_items = [(item["image"], item["title"]) for item in loras]
403
  # Update selected_info and images
404
+ selected_info_1 = "Select a LoRA 1"
405
+ selected_info_2 = "Select a LoRA 2"
406
  lora_scale_1 = 0.95
407
  lora_scale_2 = 0.95
408
  lora_image_1 = None
409
  lora_image_2 = None
410
  if len(selected_indices) >= 1:
411
  lora1 = loras[selected_indices[0]]
412
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
413
  lora_image_1 = lora1['image']
414
  if len(selected_indices) >= 2:
415
  lora2 = loras[selected_indices[1]]
416
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
417
  lora_image_2 = lora2['image']
418
+ return (
419
+ gr.update(value=gallery_items),
420
+ selected_info_1,
421
+ selected_info_2,
422
+ selected_indices,
423
+ lora_scale_1,
424
+ lora_scale_2,
425
+ lora_image_1,
426
+ lora_image_2
427
+ )
428
 
429
  run_lora.zerogpu = True
430
 
 
433
  #title{text-align: center}
434
  #title h1{font-size: 3em; display:inline-flex; align-items:center}
435
  #title img{width: 100px; margin-right: 0.5em}
436
+ #gallery{height: 260px}
437
  #gallery .grid-wrap{height: 5vh}
438
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
439
  .custom_lora_card{margin-bottom: 1em}
 
484
  remove_button_2 = gr.Button("Remove", size="sm")
485
  with gr.Row():
486
  with gr.Column():
487
+ with gr.Group():
488
+ custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="multimodalart/vintage-ads-flux")
489
+ add_custom_lora_button = gr.Button("Add Custom LoRA")
490
+ remove_custom_lora_button = gr.Button("Remove Custom LoRA")
491
+ gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
492
  gallery = gr.Gallery(
493
  [(item["image"], item["title"]) for item in loras],
494
  label="LoRA Gallery",
 
496
  columns=5,
497
  elem_id="gallery"
498
  )
 
 
 
 
 
499
  with gr.Column():
500
  progress_bar = gr.Markdown(elem_id="progress", visible=False)
501
  result = gr.Image(label="Generated Image")
 
508
  with gr.Row():
509
  cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
510
  steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
511
+
512
  with gr.Row():
513
  width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
514
  height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
515
+
516
  with gr.Row():
517
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
518
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
519
+
520
  gallery.select(
521
  update_selection,
522
  inputs=[selected_indices, width, height],
 
537
  inputs=[selected_indices],
538
  outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
539
  )
540
+ add_custom_lora_button.click(
541
  add_custom_lora,
542
  inputs=[custom_lora, selected_indices],
543
+ outputs=[gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
544
  )
545
+ remove_custom_lora_button.click(
546
  remove_custom_lora,
547
+ inputs=[selected_indices],
548
+ outputs=[gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
549
  )
550
  gr.on(
551
  triggers=[generate_button.click, prompt.submit],