jiuface committed on
Commit
ae1ab67
1 Parent(s): ec07421
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +100 -134
  3. requirements.txt +1 -1
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Flux Inpaint
3
  emoji: 🖼
4
  colorFrom: purple
5
  colorTo: red
 
1
  ---
2
+ title: Flux-dev controlnet inpainting with lora
3
  emoji: 🖼
4
  colorFrom: purple
5
  colorTo: red
app.py CHANGED
@@ -22,42 +22,36 @@ from diffusers.utils import load_image, make_image_grid
22
 
23
  import json
24
  from preprocessor import Preprocessor
25
- from diffusers.pipelines.flux.pipeline_flux_controlnet_inpaint import FluxControlNetInpaintPipeline
26
- from diffusers.models.controlnet_flux import FluxControlNetModel
27
- from diffusers.models import FluxMultiControlNetModel
28
 
29
  HF_TOKEN = os.environ.get("HF_TOKEN")
30
 
31
  login(token=HF_TOKEN)
32
 
33
  MAX_SEED = np.iinfo(np.int32).max
34
- IMAGE_SIZE = 768
35
 
36
  # init
37
  device = "cuda" if torch.cuda.is_available() else "cpu"
38
  base_model = "black-forest-labs/FLUX.1-dev"
39
 
40
- controlnet_model = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
41
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
42
- controlnet = FluxMultiControlNetModel([controlnet])
43
 
44
 
45
  pipe = FluxControlNetInpaintPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
46
- torch.backends.cuda.matmul.allow_tf32 = True
47
- pipe.vae.enable_tiling()
48
- pipe.vae.enable_slicing()
49
  # pipe.enable_model_cpu_offload() # for saving memory
50
 
51
  control_mode_ids = {
52
- "scribble_hed": 0,
53
  "canny": 0, # supported
54
- "mlsd": 0, # supported
55
  "tile": 1, # supported
56
- "depth_midas": 2, # supported
57
  "blur": 3, # supported
58
- "openpose": 4, # supported
59
  "gray": 5, # supported
60
- "low_quality": 6, # supported
61
  }
62
 
63
  def clear_cuda_cache():
@@ -126,37 +120,36 @@ def process_mask(
126
  return mask
127
 
128
  def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
129
- print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
130
- connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
131
-
132
- s3 = boto3.client(
133
- 's3',
134
- endpoint_url=connectionUrl,
135
- region_name='auto',
136
- aws_access_key_id=access_key,
137
- aws_secret_access_key=secret_key
138
- )
 
139
 
140
- current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
141
- image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
142
- buffer = BytesIO()
143
- image.save(buffer, "PNG")
144
- buffer.seek(0)
145
- s3.upload_fileobj(buffer, bucket_name, image_file)
146
- print("upload finish", image_file)
147
-
148
  return image_file
149
 
150
-
 
151
  def run_flux(
152
  image: Image.Image,
153
  mask: Image.Image,
154
  control_image: Image.Image,
155
  control_mode: int,
156
  prompt: str,
157
- lora_path: str,
158
- lora_weights: str,
159
- lora_scale: float,
160
  seed_slicer: int,
161
  randomize_seed_checkbox: bool,
162
  strength_slider: float,
@@ -165,12 +158,6 @@ def run_flux(
165
  progress
166
  ) -> Image.Image:
167
  print("Running FLUX...")
168
- clear_cuda_cache()
169
- if lora_path and lora_weights:
170
- with calculateDuration("load lora"):
171
- print("start to load lora", lora_path, lora_weights)
172
- pipe.load_lora_weights(lora_path, weight_name=lora_weights)
173
-
174
  width, height = resolution_wh
175
  if randomize_seed_checkbox:
176
  seed_slicer = random.randint(0, MAX_SEED)
@@ -184,22 +171,73 @@ def run_flux(
184
  prompt=prompt,
185
  image=image,
186
  mask_image=mask,
187
- control_image=[control_image],
188
- control_mode=[control_mode],
189
  controlnet_conditioning_scale=[0.55],
190
  width=width,
191
  height=height,
192
  strength=strength_slider,
193
  generator=generator,
194
  num_inference_steps=num_inference_steps_slider,
195
- # max_sequence_length=256,
196
- joint_attention_kwargs={"scale": lora_scale}
197
  ).images[0]
198
  progress(99, "Generate image success!")
199
  return generated_image
200
 
201
- @spaces.GPU(duration=120)
202
- @torch.inference_mode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  def process(
205
  image_url: str,
@@ -212,9 +250,7 @@ def process(
212
  randomize_seed_checkbox: bool,
213
  strength_slider: float,
214
  num_inference_steps_slider: int,
215
- lora_path: str,
216
- lora_weights: str,
217
- lora_scale: str,
218
  upload_to_r2: bool,
219
  account_id: str,
220
  access_key: str,
@@ -251,54 +287,12 @@ def process(
251
  mask = mask.resize((width, height), Image.LANCZOS)
252
  mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
253
 
254
-
255
- # generated control_
256
- with calculateDuration("Preprocessor Image"):
257
- print("start to generate control image")
258
- preprocessor = Preprocessor()
259
- if control_mode == "depth_midas":
260
- preprocessor.load("Midas")
261
- control_image = preprocessor(
262
- image=image,
263
- image_resolution=width,
264
- detect_resolution=512,
265
- )
266
- if control_mode == "openpose":
267
- preprocessor.load("Openpose")
268
- control_image = preprocessor(
269
- image=image,
270
- hand_and_face=False,
271
- image_resolution=width,
272
- detect_resolution=512,
273
- )
274
- if control_mode == "canny":
275
- preprocessor.load("Canny")
276
- control_image = preprocessor(
277
- image=image,
278
- image_resolution=width,
279
- detect_resolution=512,
280
- )
281
-
282
- if control_mode == "mlsd":
283
- preprocessor.load("MLSD")
284
- control_image = preprocessor(
285
- image=image_before,
286
- image_resolution=width,
287
- detect_resolution=512,
288
- )
289
-
290
-
291
- if control_mode == "scribble_hed":
292
- preprocessor.load("HED")
293
- control_image = preprocessor(
294
- image=image_before,
295
- image_resolution=image_resolution,
296
- detect_resolution=preprocess_resolution,
297
- )
298
-
299
- control_image = control_image.resize((width, height), Image.LANCZOS)
300
  control_mode_id = control_mode_ids[control_mode]
301
  clear_cuda_cache()
 
 
 
302
  try:
303
  generated_image = run_flux(
304
  image=image,
@@ -306,9 +300,6 @@ def process(
306
  control_image=control_image,
307
  control_mode=control_mode_id,
308
  prompt=inpainting_prompt_text,
309
- lora_path=lora_path,
310
- lora_scale=lora_scale,
311
- lora_weights=lora_weights,
312
  seed_slicer=seed_slicer,
313
  randomize_seed_checkbox=randomize_seed_checkbox,
314
  strength_slider=strength_slider,
@@ -321,16 +312,16 @@ def process(
321
  result["message"] = "generate image failed"
322
  print(e)
323
  generated_image = None
 
324
  clear_cuda_cache()
325
  print("run flux finish")
326
  if generated_image:
327
  if upload_to_r2:
328
- with calculateDuration("upload image"):
329
- url = upload_image_to_r2(generated_image, account_id, access_key, secret_key, bucket)
330
- result = {"status": "success", "message": "upload image success", "url": url}
331
  else:
332
  result = {"status": "success", "message": "Image generated but not uploaded"}
333
-
334
  clear_cuda_cache()
335
  final_images = []
336
  final_images.append(image)
@@ -344,7 +335,7 @@ def process(
344
 
345
 
346
  with gr.Blocks() as demo:
347
- gr.Markdown("Flux inpaint with lora")
348
  with gr.Row():
349
  with gr.Column():
350
 
@@ -367,41 +358,18 @@ with gr.Blocks() as demo:
367
  inpainting_prompt_text_component = gr.Text(
368
  label="Inpainting prompt",
369
  show_label=True,
370
- max_lines=1,
371
  placeholder="Enter text to generate inpainting",
372
  container=False,
373
  )
374
 
375
  control_mode = gr.Dropdown(
376
- [ "canny", "depth_midas", "openpose", "mlsd"], label="Controlnet Model", info="choose controlnet model!", value="openpose"
377
  )
 
378
 
379
  submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)
380
 
381
- with gr.Accordion("Lora Settings", open=True):
382
- lora_path = gr.Textbox(
383
- label="Lora model path",
384
- show_label=True,
385
- max_lines=1,
386
- placeholder="Enter your model path",
387
- info="Currently, only LoRA hosted on Hugging Face'model can be loaded properly.",
388
- value=""
389
- )
390
- lora_weights = gr.Textbox(
391
- label="Lora weights",
392
- show_label=True,
393
- max_lines=1,
394
- placeholder="Enter your lora weights name",
395
- value=""
396
- )
397
- lora_scale = gr.Slider(
398
- label="Lora scale",
399
- show_label=True,
400
- minimum=0,
401
- maximum=1,
402
- step=0.1,
403
- value=0.9,
404
- )
405
 
406
  with gr.Accordion("Advanced Settings", open=False):
407
 
@@ -487,9 +455,7 @@ with gr.Blocks() as demo:
487
  randomize_seed_checkbox_component,
488
  strength_slider_component,
489
  num_inference_steps_slider_component,
490
- lora_path,
491
- lora_weights,
492
- lora_scale,
493
  upload_to_r2,
494
  account_id,
495
  access_key,
 
22
 
23
  import json
24
  from preprocessor import Preprocessor
25
+ from diffusers import FluxControlNetInpaintPipeline
26
+ from diffusers.models import FluxControlNetModel
 
27
 
28
  HF_TOKEN = os.environ.get("HF_TOKEN")
29
 
30
  login(token=HF_TOKEN)
31
 
32
  MAX_SEED = np.iinfo(np.int32).max
33
+ IMAGE_SIZE = 512
34
 
35
  # init
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
37
  base_model = "black-forest-labs/FLUX.1-dev"
38
 
39
+ controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
40
  controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
 
41
 
42
 
43
  pipe = FluxControlNetInpaintPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
44
+
 
 
45
  # pipe.enable_model_cpu_offload() # for saving memory
46
 
47
  control_mode_ids = {
 
48
  "canny": 0, # supported
 
49
  "tile": 1, # supported
50
+ "depth": 2, # supported
51
  "blur": 3, # supported
52
+ "pose": 4, # supported
53
  "gray": 5, # supported
54
+ "lq": 6, # supported
55
  }
56
 
57
  def clear_cuda_cache():
 
120
  return mask
121
 
122
  def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
123
+ with calculateDuration("Upload image"):
124
+ print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
125
+ connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
126
+
127
+ s3 = boto3.client(
128
+ 's3',
129
+ endpoint_url=connectionUrl,
130
+ region_name='auto',
131
+ aws_access_key_id=access_key,
132
+ aws_secret_access_key=secret_key
133
+ )
134
 
135
+ current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
136
+ image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
137
+ buffer = BytesIO()
138
+ image.save(buffer, "PNG")
139
+ buffer.seek(0)
140
+ s3.upload_fileobj(buffer, bucket_name, image_file)
141
+ print("upload finish", image_file)
142
+
143
  return image_file
144
 
145
+ @spaces.GPU(duration=120)
146
+ @torch.inference_mode()
147
  def run_flux(
148
  image: Image.Image,
149
  mask: Image.Image,
150
  control_image: Image.Image,
151
  control_mode: int,
152
  prompt: str,
 
 
 
153
  seed_slicer: int,
154
  randomize_seed_checkbox: bool,
155
  strength_slider: float,
 
158
  progress
159
  ) -> Image.Image:
160
  print("Running FLUX...")
 
 
 
 
 
 
161
  width, height = resolution_wh
162
  if randomize_seed_checkbox:
163
  seed_slicer = random.randint(0, MAX_SEED)
 
171
  prompt=prompt,
172
  image=image,
173
  mask_image=mask,
174
+ control_image=control_image,
175
+ control_mode=control_mode,
176
  controlnet_conditioning_scale=[0.55],
177
  width=width,
178
  height=height,
179
  strength=strength_slider,
180
  generator=generator,
181
  num_inference_steps=num_inference_steps_slider,
 
 
182
  ).images[0]
183
  progress(99, "Generate image success!")
184
  return generated_image
185
 
186
+
187
+ def load_loras(lora_strings_json:str):
188
+ if lora_strings_json:
189
+ try:
190
+ lora_configs = json.loads(lora_strings_json)
191
+ except:
192
+ lora_configs = None
193
+ if lora_configs:
194
+ with calculateDuration("Loading LoRA weights"):
195
+ pipe.unload_lora_weights()
196
+ adapter_names = []
197
+ adapter_weights = []
198
+ for lora_info in lora_configs:
199
+ lora_repo = lora_info.get("repo")
200
+ weights = lora_info.get("weights")
201
+ adapter_name = lora_info.get("adapter_name")
202
+ adapter_weight = lora_info.get("adapter_weight")
203
+ if lora_repo and weights and adapter_name:
204
+ # load lora
205
+ pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
206
+ adapter_names.append(adapter_name)
207
+ adapter_weights.append(adapter_weight)
208
+ # set lora weights
209
+ pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
210
+
211
+
212
+ def generate_control_image(orginal_image, mask, control_mode):
213
+ # generate control image
214
+ with calculateDuration("Generate control image"):
215
+ preprocessor = Preprocessor()
216
+ if control_mode == "depth":
217
+ preprocessor.load("Midas")
218
+ control_image = preprocessor(
219
+ image=image,
220
+ image_resolution=width,
221
+ detect_resolution=512,
222
+ )
223
+ if control_mode == "pose":
224
+ preprocessor.load("Openpose")
225
+ control_image = preprocessor(
226
+ image=image,
227
+ hand_and_face=False,
228
+ image_resolution=width,
229
+ detect_resolution=512,
230
+ )
231
+ if control_mode == "canny":
232
+ preprocessor.load("Canny")
233
+ control_image = preprocessor(
234
+ image=image,
235
+ image_resolution=width,
236
+ detect_resolution=512,
237
+ )
238
+
239
+ control_image = control_image.resize((width, height), Image.LANCZOS)
240
+ return control_image
241
 
242
  def process(
243
  image_url: str,
 
250
  randomize_seed_checkbox: bool,
251
  strength_slider: float,
252
  num_inference_steps_slider: int,
253
+ lora_strings_json: str,
 
 
254
  upload_to_r2: bool,
255
  account_id: str,
256
  access_key: str,
 
287
  mask = mask.resize((width, height), Image.LANCZOS)
288
  mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
289
 
290
+ control_image = generate_control_image(image, mask, control_mode)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  control_mode_id = control_mode_ids[control_mode]
292
  clear_cuda_cache()
293
+
294
+ load_loras(lora_strings_json=lora_strings_json)
295
+
296
  try:
297
  generated_image = run_flux(
298
  image=image,
 
300
  control_image=control_image,
301
  control_mode=control_mode_id,
302
  prompt=inpainting_prompt_text,
 
 
 
303
  seed_slicer=seed_slicer,
304
  randomize_seed_checkbox=randomize_seed_checkbox,
305
  strength_slider=strength_slider,
 
312
  result["message"] = "generate image failed"
313
  print(e)
314
  generated_image = None
315
+
316
  clear_cuda_cache()
317
  print("run flux finish")
318
  if generated_image:
319
  if upload_to_r2:
320
+ url = upload_image_to_r2(generated_image, account_id, access_key, secret_key, bucket)
321
+ result = {"status": "success", "message": "upload image success", "url": url}
 
322
  else:
323
  result = {"status": "success", "message": "Image generated but not uploaded"}
324
+
325
  clear_cuda_cache()
326
  final_images = []
327
  final_images.append(image)
 
335
 
336
 
337
  with gr.Blocks() as demo:
338
+ gr.Markdown("Flux controlnet inpaint with loras")
339
  with gr.Row():
340
  with gr.Column():
341
 
 
358
  inpainting_prompt_text_component = gr.Text(
359
  label="Inpainting prompt",
360
  show_label=True,
361
+ max_lines=5,
362
  placeholder="Enter text to generate inpainting",
363
  container=False,
364
  )
365
 
366
  control_mode = gr.Dropdown(
367
+ [ "canny", "depth", "pose"], label="Controlnet Model", info="choose controlnet model!", value="canny"
368
  )
369
+ lora_strings_json = gr.Text(label="LoRA Configs (JSON List String)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1", "adapter_weight": 1}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2", "adapter_weight": 1}]', lines=5)
370
 
371
  submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  with gr.Accordion("Advanced Settings", open=False):
375
 
 
455
  randomize_seed_checkbox_component,
456
  strength_slider_component,
457
  num_inference_steps_slider_component,
458
+ lora_strings_json,
 
 
459
  upload_to_r2,
460
  account_id,
461
  access_key,
requirements.txt CHANGED
@@ -7,7 +7,7 @@ einops
7
  spaces
8
  gradio
9
  opencv-python
10
- git+https://github.com/mylovelycodes/diffusers.git
11
  boto3
12
  sentencepiece
13
  peft
 
7
  spaces
8
  gradio
9
  opencv-python
10
+ git+https://github.com/diffusers/diffusers.git
11
  boto3
12
  sentencepiece
13
  peft