bugfix

Files changed:
- README.md +1 -1
- app.py +100 -134
- requirements.txt +1 -1
README.md CHANGED

@@ -1,5 +1,5 @@
 ---
-title: Flux
+title: Flux-dev controlnet inpainting with lora
 emoji: 🖼
 colorFrom: purple
 colorTo: red
app.py CHANGED

@@ -22,42 +22,36 @@ from diffusers.utils import load_image, make_image_grid
 
 import json
 from preprocessor import Preprocessor
-from diffusers import FluxControlNetInpaintPipeline
-from diffusers.models import FluxControlNetModel
-from diffusers.models import FluxMultiControlNetModel
+from diffusers import FluxControlNetInpaintPipeline
+from diffusers.models import FluxControlNetModel
 
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
 login(token=HF_TOKEN)
 
 MAX_SEED = np.iinfo(np.int32).max
-IMAGE_SIZE =
+IMAGE_SIZE = 512
 
 # init
 device = "cuda" if torch.cuda.is_available() else "cpu"
 base_model = "black-forest-labs/FLUX.1-dev"
 
-controlnet_model = '
+controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
 controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
-controlnet = FluxMultiControlNetModel([controlnet])
 
 
 pipe = FluxControlNetInpaintPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
-
-pipe.vae.enable_tiling()
-pipe.vae.enable_slicing()
+
 # pipe.enable_model_cpu_offload() # for saving memory
 
 control_mode_ids = {
-    "scribble_hed": 0,
     "canny": 0, # supported
-    "mlsd": 0, # supported
     "tile": 1, # supported
-    "depth_midas": 2,
+    "depth": 2, # supported
     "blur": 3, # supported
-    "openpose": 4,
+    "pose": 4, # supported
     "gray": 5, # supported
-    "
+    "lq": 6, # supported
 }
 
 def clear_cuda_cache():
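A note on the mapping above: the InstantX Union ControlNet packs several conditioning types into one checkpoint and selects among them with an integer control_mode, which is why this commit can drop the FluxMultiControlNetModel wrapper and pass a single model plus a mode id. A minimal sketch, with names mirroring app.py and the pipeline call abbreviated rather than the app's full argument list:

    # Sketch only: resolve a UI mode name to the Union ControlNet's integer mode id.
    control_mode_ids = {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}

    def resolve_control_mode(name: str) -> int:
        # A KeyError here means the UI offered a mode this checkpoint does not support.
        return control_mode_ids[name]

    # Abbreviated call site:
    # pipe(..., control_image=control_image, control_mode=resolve_control_mode("canny"), ...)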
@@ -126,37 +120,36 @@ def process_mask(
     return mask
 
 def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
-    …
+    with calculateDuration("Upload image"):
+        print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
+        connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
+
+        s3 = boto3.client(
+            's3',
+            endpoint_url=connectionUrl,
+            region_name='auto',
+            aws_access_key_id=access_key,
+            aws_secret_access_key=secret_key
+        )
 
-    …
+        current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
+        image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
+        buffer = BytesIO()
+        image.save(buffer, "PNG")
+        buffer.seek(0)
+        s3.upload_fileobj(buffer, bucket_name, image_file)
+        print("upload finish", image_file)
+
     return image_file
 
-
+@spaces.GPU(duration=120)
+@torch.inference_mode()
 def run_flux(
     image: Image.Image,
     mask: Image.Image,
     control_image: Image.Image,
     control_mode: int,
     prompt: str,
-    lora_path: str,
-    lora_weights: str,
-    lora_scale: float,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
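For context, the body the commit gives upload_image_to_r2 is the standard boto3-against-Cloudflare-R2 pattern (R2 speaks the S3 API; only the endpoint URL is Cloudflare-specific). A minimal standalone sketch, where the key parameter replaces the app's timestamped filename and all credentials are placeholders:

    from io import BytesIO
    import boto3
    from PIL import Image

    def upload_png_to_r2(image: Image.Image, account_id: str, access_key: str,
                         secret_key: str, bucket_name: str, key: str) -> str:
        # R2 exposes an S3-compatible endpoint per account.
        s3 = boto3.client(
            "s3",
            endpoint_url=f"https://{account_id}.r2.cloudflarestorage.com",
            region_name="auto",
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )
        buffer = BytesIO()
        image.save(buffer, "PNG")  # serialize in memory; no temp file needed
        buffer.seek(0)             # rewind before streaming the body
        s3.upload_fileobj(buffer, bucket_name, key)
        return key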
@@ -165,12 +158,6 @@ def run_flux(
     progress
 ) -> Image.Image:
     print("Running FLUX...")
-    clear_cuda_cache()
-    if lora_path and lora_weights:
-        with calculateDuration("load lora"):
-            print("start to load lora", lora_path, lora_weights)
-            pipe.load_lora_weights(lora_path, weight_name=lora_weights)
-
     width, height = resolution_wh
     if randomize_seed_checkbox:
         seed_slicer = random.randint(0, MAX_SEED)
@@ -184,22 +171,73 @@ def run_flux(
         prompt=prompt,
         image=image,
         mask_image=mask,
-        control_image=
-        control_mode=
+        control_image=control_image,
+        control_mode=control_mode,
         controlnet_conditioning_scale=[0.55],
         width=width,
         height=height,
         strength=strength_slider,
         generator=generator,
         num_inference_steps=num_inference_steps_slider,
-        # max_sequence_length=256,
-        joint_attention_kwargs={"scale": lora_scale}
     ).images[0]
     progress(99, "Generate image success!")
     return generated_image
 
-
-
+
+def load_loras(lora_strings_json: str):
+    if lora_strings_json:
+        try:
+            lora_configs = json.loads(lora_strings_json)
+        except:
+            lora_configs = None
+        if lora_configs:
+            with calculateDuration("Loading LoRA weights"):
+                pipe.unload_lora_weights()
+                adapter_names = []
+                adapter_weights = []
+                for lora_info in lora_configs:
+                    lora_repo = lora_info.get("repo")
+                    weights = lora_info.get("weights")
+                    adapter_name = lora_info.get("adapter_name")
+                    adapter_weight = lora_info.get("adapter_weight")
+                    if lora_repo and weights and adapter_name:
+                        # load lora
+                        pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                        adapter_names.append(adapter_name)
+                        adapter_weights.append(adapter_weight)
+                # set lora weights
+                pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+
+
+def generate_control_image(orginal_image, mask, control_mode):
+    # generated control_
+    with calculateDuration("Generate control image"):
+        preprocessor = Preprocessor()
+        if control_mode == "depth":
+            preprocessor.load("Midas")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+        if control_mode == "pose":
+            preprocessor.load("Openpose")
+            control_image = preprocessor(
+                image=image,
+                hand_and_face=False,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+        if control_mode == "canny":
+            preprocessor.load("Canny")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+
+        control_image = control_image.resize((width, height), Image.LANCZOS)
+        return control_image
 
 def process(
     image_url: str,
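The new load_loras helper replaces the old single lora_path/lora_weights pair with a JSON list, so several adapters can be loaded at once and weighted together via pipe.set_adapters. A plausible input for it, where the repository and file names are illustrative placeholders rather than values from the commit:

    import json

    # Hypothetical payload; "repo", "weights", and "adapter_name" values are
    # placeholders for real Hugging Face LoRA repos and weight files.
    lora_strings_json = json.dumps([
        {"repo": "some-user/flux-style-lora", "weights": "style.safetensors",
         "adapter_name": "style", "adapter_weight": 0.9},
        {"repo": "some-user/flux-detail-lora", "weights": "detail.safetensors",
         "adapter_name": "detail", "adapter_weight": 0.5},
    ])

    load_loras(lora_strings_json=lora_strings_json)  # loads both adapters onto `pipe`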
@@ -212,9 +250,7 @@ def process(
     randomize_seed_checkbox: bool,
     strength_slider: float,
     num_inference_steps_slider: int,
-    lora_path: str,
-    lora_weights: str,
-    lora_scale: str,
+    lora_strings_json: str,
     upload_to_r2: bool,
     account_id: str,
     access_key: str,
@@ -251,54 +287,12 @@ def process(
     mask = mask.resize((width, height), Image.LANCZOS)
     mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
 
-
-    # generated control_
-    with calculateDuration("Preprocessor Image"):
-        print("start to generate control image")
-        preprocessor = Preprocessor()
-        if control_mode == "depth_midas":
-            preprocessor.load("Midas")
-            control_image = preprocessor(
-                image=image,
-                image_resolution=width,
-                detect_resolution=512,
-            )
-        if control_mode == "openpose":
-            preprocessor.load("Openpose")
-            control_image = preprocessor(
-                image=image,
-                hand_and_face=False,
-                image_resolution=width,
-                detect_resolution=512,
-            )
-        if control_mode == "canny":
-            preprocessor.load("Canny")
-            control_image = preprocessor(
-                image=image,
-                image_resolution=width,
-                detect_resolution=512,
-            )
-
-        if control_mode == "mlsd":
-            preprocessor.load("MLSD")
-            control_image = preprocessor(
-                image=image_before,
-                image_resolution=width,
-                detect_resolution=512,
-            )
-
-
-        if control_mode == "scribble_hed":
-            preprocessor.load("HED")
-            control_image = preprocessor(
-                image=image_before,
-                image_resolution=image_resolution,
-                detect_resolution=preprocess_resolution,
-            )
-
-        control_image = control_image.resize((width, height), Image.LANCZOS)
+    control_image = generate_control_image(image, mask, control_mode)
     control_mode_id = control_mode_ids[control_mode]
     clear_cuda_cache()
+
+    load_loras(lora_strings_json=lora_strings_json)
+
     try:
         generated_image = run_flux(
             image=image,
@@ -306,9 +300,6 @@ def process(
             control_image=control_image,
             control_mode=control_mode_id,
             prompt=inpainting_prompt_text,
-            lora_path=lora_path,
-            lora_scale=lora_scale,
-            lora_weights=lora_weights,
             seed_slicer=seed_slicer,
             randomize_seed_checkbox=randomize_seed_checkbox,
             strength_slider=strength_slider,
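One caveat with the generate_control_image helper wired in above: its signature takes orginal_image (sic) and mask, but its body references image and width, names that are local to process rather than to the helper, so as committed it appears to depend on variables outside its own scope. A self-contained variant might look like the sketch below, with the Preprocessor interface inferred from the app's own calls:

    from PIL import Image
    from preprocessor import Preprocessor  # the Space's local preprocessing module

    def generate_control_image_fixed(original_image: Image.Image, control_mode: str,
                                     width: int, height: int) -> Image.Image:
        # Everything the body needs arrives as a parameter.
        loaders = {"depth": "Midas", "pose": "Openpose", "canny": "Canny"}
        preprocessor = Preprocessor()
        preprocessor.load(loaders[control_mode])
        kwargs = dict(image=original_image, image_resolution=width, detect_resolution=512)
        if control_mode == "pose":
            kwargs["hand_and_face"] = False  # mirrors the committed Openpose call
        control_image = preprocessor(**kwargs)
        return control_image.resize((width, height), Image.LANCZOS)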
@@ -321,16 +312,16 @@ def process(
         result["message"] = "generate image failed"
         print(e)
         generated_image = None
+
     clear_cuda_cache()
     print("run flux finish")
     if generated_image:
         if upload_to_r2:
-            …
-            …
-            result = {"status": "success", "message": "upload image success", "url": url}
+            url = upload_image_to_r2(generated_image, account_id, access_key, secret_key, bucket)
+            result = {"status": "success", "message": "upload image success", "url": url}
         else:
             result = {"status": "success", "message": "Image generated but not uploaded"}
-
+
     clear_cuda_cache()
     final_images = []
     final_images.append(image)

@@ -344,7 +335,7 @@ def process(
 
 
 with gr.Blocks() as demo:
-    gr.Markdown("Flux inpaint with
+    gr.Markdown("Flux controlnet inpaint with loras")
     with gr.Row():
         with gr.Column():
 
@@ -367,41 +358,18 @@ with gr.Blocks() as demo:
             inpainting_prompt_text_component = gr.Text(
                 label="Inpainting prompt",
                 show_label=True,
-                max_lines=
+                max_lines=5,
                 placeholder="Enter text to generate inpainting",
                 container=False,
             )
 
             control_mode = gr.Dropdown(
-                [ "canny", "
+                [ "canny", "depth", "pose"], label="Controlnet Model", info="choose controlnet model!", value="canny"
             )
+            lora_strings_json = gr.Text(label="LoRA Configs (JSON List String)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1", "adapter_weight": 1}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2", "adapter_weight": 1}]', lines=5)
 
             submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)
 
-            with gr.Accordion("Lora Settings", open=True):
-                lora_path = gr.Textbox(
-                    label="Lora model path",
-                    show_label=True,
-                    max_lines=1,
-                    placeholder="Enter your model path",
-                    info="Currently, only LoRA hosted on Hugging Face can be loaded properly.",
-                    value=""
-                )
-                lora_weights = gr.Textbox(
-                    label="Lora weights",
-                    show_label=True,
-                    max_lines=1,
-                    placeholder="Enter your lora weights name",
-                    value=""
-                )
-                lora_scale = gr.Slider(
-                    label="Lora scale",
-                    show_label=True,
-                    minimum=0,
-                    maximum=1,
-                    step=0.1,
-                    value=0.9,
-                )
 
             with gr.Accordion("Advanced Settings", open=False):
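Since the textbox above hands raw JSON straight to load_loras, and that function swallows parse failures with a bare except, a stricter parser is a cheap guard against silently ignored LoRA configs. A sketch of one possible approach, not part of the commit:

    import json

    def parse_lora_configs(raw: str):
        # Return only well-formed adapter entries; None signals bad JSON so the
        # caller can surface an error instead of silently skipping all LoRAs.
        try:
            configs = json.loads(raw)
        except json.JSONDecodeError:
            return None
        if not isinstance(configs, list):
            return None
        required = {"repo", "weights", "adapter_name"}
        return [c for c in configs if isinstance(c, dict) and required <= c.keys()]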
@@ -487,9 +455,7 @@ with gr.Blocks() as demo:
             randomize_seed_checkbox_component,
             strength_slider_component,
             num_inference_steps_slider_component,
-            lora_path,
-            lora_weights,
-            lora_scale,
+            lora_strings_json,
             upload_to_r2,
             account_id,
             access_key,
requirements.txt CHANGED

@@ -7,7 +7,7 @@ einops
 spaces
 gradio
 opencv-python
-git+https://github.com/
+git+https://github.com/huggingface/diffusers.git
 boto3
 sentencepiece
 peft
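requirements.txt tracks diffusers from git rather than PyPI, presumably because FluxControlNetInpaintPipeline had not yet reached a tagged release when this commit landed. A quick sanity check that the installed build actually exposes what app.py imports:

    # Sketch: verify the git build of diffusers provides the classes app.py uses.
    import diffusers
    print("diffusers version:", diffusers.__version__)

    from diffusers import FluxControlNetInpaintPipeline          # ImportError on older builds
    from diffusers.models import FluxControlNetModel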