# Captured from a Hugging Face Space whose status at capture time was "Runtime error".
import os
import random

import gradio as gr
# Imported before the local pipeline module (which presumably pulls in torch):
# ZeroGPU's `spaces` package expects to be imported before CUDA is initialised.
import spaces
from PIL import Image

from src.flux.xflux_pipeline import XFluxPipeline
def run_xflux_pipeline(
    prompt, image, repo_id, name, device,
    model_type, width, height, timestep_to_start_cfg, num_steps, true_gs, guidance,
    neg_prompt="",
    negative_image=None,
    save_path='results', control_type='depth', use_controlnet=False, seed=None,
    num_images_per_prompt=1, use_lora=False, lora_weight=0.7,
    lora_repo_id="XLabs-AI/flux-lora-collection", lora_name="realism_lora.safetensors",
    use_ip=False,
):
    """Build an XFluxPipeline, attach the requested adapters and generate images.

    Args:
        prompt: Text prompt.
        image: Path to the input image, or None/empty. Opened with PIL and used
            as the ControlNet conditioning image (when ``use_controlnet``) and
            as the IP-Adapter image prompt (when ``use_ip``).
        repo_id: Hub repo holding the ControlNet checkpoint.
        name: ControlNet checkpoint filename inside ``repo_id``.
        device: Device string forwarded to ``XFluxPipeline`` (e.g. "cuda").
        model_type: Model identifier forwarded to ``XFluxPipeline``.
        width, height, timestep_to_start_cfg, num_steps, true_gs, guidance:
            Forwarded unchanged to every pipeline call.
        neg_prompt: Negative text prompt.
        negative_image: Optional path to a negative image prompt; only used
            when ``use_ip`` is set.
        save_path: Unused; kept for interface compatibility.
        control_type: Annotator type passed to ``set_controlnet``.
        use_controlnet: Load ControlNet and condition on ``image``.
        seed: Base seed. None -> every image gets a fresh random seed;
            otherwise image ``i`` uses ``seed + i`` so runs are reproducible.
        num_images_per_prompt: Number of pipeline calls (one result each).
        use_lora, lora_weight, lora_repo_id, lora_name: LoRA settings passed to
            ``set_lora`` when ``use_lora`` is set.
        use_ip: Load the IP-Adapter and pass the image prompts to the pipeline.

    Returns:
        A list with one pipeline result per generated image.

    Raises:
        ValueError: ``use_controlnet`` is set but no ``image`` was provided.
    """
    # Fixed settings that are not exposed as parameters.
    offload = False
    local_path = None          # ControlNet: always fetched from repo_id/name
    ip_local_path = None
    ip_repo_id = "XLabs-AI/flux-ip-adapter"
    ip_name = "flux-ip-adapter.safetensors"
    ip_scale = 1.0
    neg_ip_scale = 1.0
    lora_local_path = None

    # Validate before constructing the (expensive) pipeline.
    if use_controlnet and not image:
        raise ValueError("An input image is required when ControlNet is enabled.")

    # Open the input image once and reuse it for every role it plays.
    input_img = Image.open(image) if image else None
    control_img = input_img if use_controlnet else None
    # Image prompts are only meaningful when the IP-Adapter is loaded below;
    # presumably the pipeline rejects them otherwise — confirm in XFluxPipeline.
    if use_ip:
        img_prompt = input_img
        neg_img_prompt = Image.open(negative_image) if negative_image else None
    else:
        img_prompt = None
        neg_img_prompt = None

    # NOTE(review): a new pipeline (and model weights) is built on every call.
    # Caching it needs care because set_controlnet/set_ip/set_lora presumably
    # modify the instance.
    xflux_pipeline = XFluxPipeline(model_type, device, offload)

    # Attach the optional adapters.
    if use_controlnet:
        print('Loading ControlNet:', local_path, repo_id, name)
        xflux_pipeline.set_controlnet(control_type, local_path, repo_id, name)
    if use_ip:
        print('load ip-adapter:', ip_local_path, ip_repo_id, ip_name)
        xflux_pipeline.set_ip(ip_local_path, ip_repo_id, ip_name)
    if use_lora:
        print('load lora:', lora_local_path, lora_repo_id, lora_name)
        xflux_pipeline.set_lora(lora_local_path, lora_repo_id, lora_name, lora_weight)

    # Generate the images, one pipeline call each.
    images = []
    for i in range(num_images_per_prompt):
        image_seed = random.randint(0, 2147483647) if seed is None else seed + i
        images.append(
            xflux_pipeline(
                prompt=prompt,
                controlnet_image=control_img,
                width=width,
                height=height,
                guidance=guidance,
                num_steps=num_steps,
                seed=image_seed,
                true_gs=true_gs,
                neg_prompt=neg_prompt,
                timestep_to_start_cfg=timestep_to_start_cfg,
                image_prompt=img_prompt,
                neg_image_prompt=neg_img_prompt,
                ip_scale=ip_scale,
                neg_ip_scale=neg_ip_scale,
            )
        )
    return images
@spaces.GPU
def _run_on_gpu(**kwargs):
    """Run run_xflux_pipeline inside a ZeroGPU allocation.

    Per the Hugging Face ZeroGPU docs, CUDA is only available inside
    @spaces.GPU functions there, and the decorator has no effect elsewhere.
    TODO(review): the default GPU time budget may be too short to load
    FLUX-dev and generate 4 images; consider @spaces.GPU(duration=...).
    """
    return run_xflux_pipeline(**kwargs)


def process_image(image, prompt, steps, use_lora, use_controlnet, use_depth, use_hed, use_ip, lora_name, lora_path, lora_weight, negative_image, neg_prompt, true_gs, guidance, cfg):
    """Gradio click handler: map the UI controls onto run_xflux_pipeline.

    Always generates 4 images at 1024x1024 with the flux-dev model on CUDA.
    The ControlNet annotator is depth if ``use_depth``, else hed if
    ``use_hed``, else canny. The matching checkpoint is loaded from
    XLabs-AI/flux-controlnet-collections.

    Args:
        image: Input image file path (None when nothing was uploaded).
        prompt, neg_prompt: Positive / negative text prompts.
        steps: Number of denoising steps.
        use_lora, lora_name, lora_path, lora_weight: LoRA toggle, filename,
            repo id and strength.
        use_controlnet, use_depth, use_hed: ControlNet toggle and annotator
            choice.
        use_ip: Whether to use the IP-Adapter with the image prompts.
        negative_image: Optional negative image prompt file path.
        true_gs, guidance: Forwarded guidance values.
        cfg: Forwarded as ``timestep_to_start_cfg``.

    Returns:
        The list of generated images for the gallery.

    Raises:
        gr.Error: ControlNet is enabled but no image was uploaded.
    """
    # Validate before requesting a GPU so the user gets a readable message.
    if use_controlnet and not image:
        raise gr.Error("Please upload an image to use ControlNet.")

    control_type = "depth" if use_depth else "hed" if use_hed else "canny"
    return _run_on_gpu(
        prompt=prompt,
        neg_prompt=neg_prompt,
        image=image,
        negative_image=negative_image,
        lora_name=lora_name,
        lora_weight=lora_weight,
        lora_repo_id=lora_path,
        control_type=control_type,
        repo_id="XLabs-AI/flux-controlnet-collections",
        # Load the checkpoint that matches the chosen annotator (this was
        # always the depth checkpoint before). Assumes the repo names files
        # flux-<type>-controlnet.safetensors — verify on the Hub.
        name=f"flux-{control_type}-controlnet.safetensors",
        device="cuda",
        use_controlnet=use_controlnet,
        model_type="flux-dev",
        width=1024,
        height=1024,
        # Sliders may deliver floats; step counts must be integers.
        timestep_to_start_cfg=int(cfg),
        num_steps=int(steps),
        num_images_per_prompt=4,
        use_lora=use_lora,
        true_gs=true_gs,
        use_ip=use_ip,
        guidance=guidance,
    )
# Dark background overrides for the page body, the Gradio container, the main
# app row and a sidebar class.
custom_css = """
body {
background: rgb(24, 24, 27);
}
.gradio-container {
background: rgb(24, 24, 27);
}
.app-container {
background: rgb(24, 24, 27);
}
gradio-app {
background: rgb(24, 24, 27);
}
.sidebar {
background: rgb(31, 31, 35);
border-right: 1px solid rgb(41, 41, 41);
}
"""

with gr.Blocks(css=custom_css) as demo:
    with gr.Row(elem_classes="app-container"):
        # Column 1: image inputs and the submit button.
        with gr.Column():
            input_image = gr.Image(label="Image", type="filepath")
            negative_image = gr.Image(label="Negative_image", type="filepath")
            submit_btn = gr.Button("Submit")
        # Column 2: prompts and generation settings.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            neg_prompt = gr.Textbox(label="Neg Prompt")
            steps = gr.Slider(step=1, minimum=1, maximum=64, value=28, label="Num Steps")
            use_lora = gr.Checkbox(label="Use LORA", value=True)
            lora_path = gr.Textbox(label="LoraPath", value="XLabs-AI/flux-lora-collection")
            lora_name = gr.Textbox(label="LoraName", value="realism_lora.safetensors")
            lora_weight = gr.Slider(step=0.1, minimum=0, maximum=1, value=0.7, label="Lora Weight")
            controlnet = gr.Checkbox(label="Use Controlnet(by default uses canny)", value=True)
            use_ip = gr.Checkbox(label="Use IP")
            use_depth = gr.Checkbox(label="Use depth")
            use_hed = gr.Checkbox(label="Use hed")
            true_gs = gr.Slider(step=0.1, minimum=0, maximum=10, value=3.5, label="TrueGs")
            guidance = gr.Slider(minimum=1, maximum=10, value=4, label="Guidance")
            cfg = gr.Slider(minimum=1, maximum=10, value=1, label="CFG")
        # Column 3: results.
        with gr.Column():
            output = gr.Gallery(label="Gallery output", elem_classes="galery", selected_index=0)
    # The input order must match process_image's positional parameters.
    submit_btn.click(
        process_image,
        inputs=[input_image, prompt, steps, use_lora, controlnet, use_depth, use_hed, use_ip,
                lora_name, lora_path, lora_weight, negative_image, neg_prompt, true_gs,
                guidance, cfg],
        outputs=output,
    )

if __name__ == "__main__":
    # share=True asks for a public link when run locally; per the Gradio docs
    # it is not supported (ignored) on Hugging Face Spaces.
    demo.launch(share=True)