# XFluxSpace / app.py
# Hugging Face Space source file; last commit "Update app.py" by stazizov
# (4e1e761, verified), 9.51 kB.
import spaces
import torch
# torch.multiprocessing.set_start_method('spawn', force=True)
import re
import os
import yaml
import tempfile
import subprocess
from pathlib import Path
from dataclasses import dataclass
import gradio as gr
from src.flux.xflux_pipeline import XFluxPipeline
import os
from huggingface_hub import login
hf_token = os.getenv("HF_TOKEN")
# Log in to the Hugging Face Hub only when HF_TOKEN is set in the environment;
# otherwise just report that no token was found and carry on.
if not hf_token:
    print("No Hugging Face token found.")
else:
    login(token=hf_token)
@dataclass
class Config:
    """Launch settings for the demo.

    The class-level defaults are also read directly (``Config.name`` etc.)
    when the module builds the shared pipeline at import time. Field order
    defines the positional order of the generated ``__init__``.
    """

    name: str = "flux-dev"  # model name passed to XFluxPipeline and shown in the UI heading
    device: str = "cpu"  # device XFluxPipeline is constructed with at import
    offload: bool = False  # offload flag forwarded to XFluxPipeline / create_demo
    share: bool = False  # "create a public link" option (cf. --share in parse_args)
    ckpt_dir: str = "."  # directory scanned for *.safetensors checkpoints
# Build the single shared pipeline once, at import time, from the Config
# class defaults; then move it to CUDA when a GPU is visible, else keep CPU.
xflux_pipeline = XFluxPipeline(
    Config.name,
    Config.device,
    Config.offload,
)
xflux_pipeline.to(
    device="cuda" if torch.cuda.is_available() else "cpu",
)
@spaces.GPU
def generate(*args, **kwargs):
    """Run one generation through the shared pipeline inside a ``spaces.GPU`` call.

    Gradio's ``Button.click`` passes the values of its ``inputs`` components
    to ``fn`` positionally, so positional arguments must be forwarded; the
    previous ``**kwargs``-only signature rejected every click with a
    ``TypeError``. Keyword arguments are still forwarded for other callers.

    Returns whatever ``xflux_pipeline.gradio_generate`` returns.
    """
    return xflux_pipeline.gradio_generate(*args, **kwargs)
def parse_args(argv=None) -> Config:
    """Parse command-line options into a ``Config``.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        A ``Config`` whose fields come from the parsed options.
    """
    # argparse is not imported at module level, so import it here; without
    # this, the parser construction below raised NameError.
    import argparse

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use",
    )
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=".",
        help="Folder with checkpoints in safetensors format",
    )
    args = parser.parse_args(argv)
    # Option dests (name, device, offload, share, ckpt_dir) match Config's fields.
    return Config(**vars(args))
def list_dirs(path):
    """Yield directory choices around *path* for a directory dropdown.

    *path* is resolved first:

    * ``None``, ``"None"`` and ``""`` yield nothing;
    * a non-existent path falls back to its parent, and nothing is yielded
      if that does not exist either;
    * a file is replaced by its containing directory, with ``"."`` standing
      in for a bare file name (``os.path.dirname`` would give ``""``, which
      ``os.listdir`` rejects).

    Then it yields, in order:

    * the resolved directory's parent, when that is non-empty;
    * the directory itself;
    * its immediate sub-directories in natural order (``b2`` before
      ``b10``), skipping hidden entries (leading ``.``) and ``__pycache__``.

    On Windows, backslashes are replaced by ``/`` in every yielded path.
    """
    if path is None or path == "None" or path == "":
        return
    if not os.path.exists(path):
        path = os.path.dirname(path)
        if not os.path.exists(path):
            return
    if not os.path.isdir(path):
        path = os.path.dirname(path) or "."

    def natural_sort_key(s, regex=re.compile(r"([0-9]+)")):
        # re.split with one capturing group alternates text / captured digit
        # run, so odd indices are always ASCII digits. Converting by position
        # (instead of str.isdigit(), which also accepts e.g. "²" that int()
        # rejects) keeps every key a str/int alternation that sorts safely.
        return [
            int(part) if index % 2 else part.lower()
            for index, part in enumerate(regex.split(s))
        ]

    subdirs = sorted(
        (
            os.path.join(path, item)
            for item in os.listdir(path)
            if not item.startswith(".")
            and item != "__pycache__"
            and os.path.isdir(os.path.join(path, item))
        ),
        key=natural_sort_key,
    )

    parent = os.path.dirname(path)
    dirs = ([parent] if parent != "" else []) + [path] + subdirs
    if os.sep == "\\":
        dirs = [d.replace("\\", "/") for d in dirs]
    yield from dirs
def list_train_data_dirs():
    """Return, as a list, the directory choices ``list_dirs`` yields for the working directory."""
    return [*list_dirs(".")]
def update_config(d, u):
    """Recursively merge the overrides *u* into the config dict *d*, in place.

    A dict value in *u* is merged into the matching sub-dict of *d*. A fresh
    dict is used when the key is missing or holds a non-dict value; a YAML
    key with no value loads as ``None``, for example.

    Leaf values are normalised before being stored:

    * objects with a ``.value`` attribute (e.g. Gradio components) become
      ``str(obj.value)``;
    * ``float`` values are kept, except that whole-valued ones become ``int``;
    * anything else is tried with ``int(v)``, then ``float(v)`` (finite
      results only, e.g. ``"1e-5"``), and otherwise stored as ``str(v)``.

    Returns:
        *d* itself, after mutation.
    """
    for k, v in u.items():
        if isinstance(v, dict):
            existing = d.get(k)
            d[k] = update_config(existing if isinstance(existing, dict) else {}, v)
        elif hasattr(v, "value"):
            # convert Gradio components to strings
            d[k] = str(v.value)
        else:
            d[k] = _coerce_config_scalar(v)
    return d


def _coerce_config_scalar(v):
    """Convert one leaf override to int, float or str as described in update_config."""
    import math

    if isinstance(v, float):
        # int(v) alone truncated fractional values (int(1e-5) == 0).
        return int(v) if v.is_integer() else v
    try:
        return int(v)
    except (TypeError, ValueError):
        pass
    try:
        as_float = float(v)
    except (TypeError, ValueError):
        return str(v)
    # Keep strings such as "inf"/"nan" (e.g. a directory name) as text.
    return as_float if math.isfinite(as_float) else str(v)
def start_lora_training(
    data_dir: str, output_dir: str, lr: float, steps: int, rank: int
):
    """Run LoRA fine-tuning through ``accelerate`` with UI-supplied overrides.

    Steps:

    1. Load ``train_configs/test_lora.yaml`` from this file's directory.
    2. Merge the arguments into it with ``update_config``
       (``data_config.img_dir``, ``output_dir``, ``learning_rate``, ``rank``,
       ``max_train_steps``).
    3. Write the result to a temporary YAML file.
    4. Run ``accelerate launch <this dir>/train_flux_lora_deepspeed.py
       --config <tmp>`` and block until it exits.

    The temporary file is removed afterwards, even if the launch fails.

    Args:
        data_dir: Directory containing the training images.
        output_dir: Directory for the output checkpoint; created if missing.
        lr: Learning rate. The UI passes it as text (e.g. ``"1e-5"``).
        steps: Number of training steps.
        rank: LoRA rank.

    Returns:
        The ``subprocess.CompletedProcess`` of the training command.

    Raises:
        subprocess.CalledProcessError: If the training command exits non-zero.
    """
    inputs = {
        "data_config": {
            "img_dir": data_dir,
        },
        "output_dir": output_dir,
        "learning_rate": lr,
        "rank": rank,
        "max_train_steps": steps,
    }

    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Creating folder {output_dir} for the output checkpoint file...")

    script_dir = Path(__file__).resolve().parent
    config_path = script_dir / "train_configs" / "test_lora.yaml"
    with open(config_path, "r", encoding="utf-8") as file:
        # An empty YAML file loads as None; treat it as an empty config.
        config = yaml.safe_load(file) or {}

    config = update_config(config, inputs)
    print("Config file is updated...", config)

    # Resolve the training script next to this file, like the config above,
    # so the launch does not depend on the current working directory.
    train_script = str(script_dir / "train_flux_lora_deepspeed.py")

    tmp_config_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".yaml", encoding="utf-8"
        ) as temp_file:
            tmp_config_path = temp_file.name
            yaml.dump(config, temp_file, default_flow_style=False)
        command = ["accelerate", "launch", train_script, "--config", tmp_config_path]
        result = subprocess.run(command, check=True)
    finally:
        # Previously the unlink was skipped whenever check=True raised,
        # leaking the temporary config file.
        if tmp_config_path is not None:
            Path(tmp_config_path).unlink(missing_ok=True)
    return result
def create_demo(
    model_type: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    ckpt_dir: str = "",
):
    """Build (but do not launch) the Gradio UI.

    The UI has an "Inference" tab and a "LoRA Finetuning" tab.

    Args:
        model_type: Model name, shown in the page heading.
        device: Not used in this body; kept for signature compatibility.
        offload: Not used in this body; kept for signature compatibility.
        ckpt_dir: Directory scanned (non-recursively) for ``*.safetensors``
            files. The sorted matches populate the ControlNet, LoRA and
            IP-Adapter checkpoint dropdowns.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    # NOTE(review): these are pathlib.Path objects used directly as Dropdown
    # choices -- confirm the installed Gradio version accepts them.
    checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))

    # Components are laid out in the order they are created inside the nested
    # ``with`` contexts, so statement order here defines the UI layout.
    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")

        with gr.Tab("Inference"):
            with gr.Row():
                # Left column: prompt and all generation options.
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
                    with gr.Accordion("Generation Options", open=False):
                        with gr.Row():
                            width = gr.Slider(512, 2048, 1024, step=16, label="Width")
                            height = gr.Slider(512, 2048, 1024, step=16, label="Height")
                        neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
                        with gr.Row():
                            num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                            timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg")
                        with gr.Row():
                            guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
                            true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True)
                        # Textbox, so the seed reaches the handler as text.
                        # NOTE(review): confirm gradio_generate parses it.
                        seed = gr.Textbox(-1, label="Seed (-1 for random)")

                    with gr.Accordion("ControlNet Options", open=False):
                        control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type")
                        control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True)
                        local_path = gr.Dropdown(checkpoints, label="Controlnet Checkpoint",
                            info="Local Path to Controlnet weights (if no, it will be downloaded from HF)"
                        )
                        controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True)

                    with gr.Accordion("LoRA Options", open=False):
                        lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True)
                        lora_local_path = gr.Dropdown(
                            checkpoints, label="LoRA Checkpoint", info="Local Path to Lora weights"
                        )

                    with gr.Accordion("IP Adapter Options", open=False):
                        image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
                        ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
                        neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True)
                        neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
                        ip_local_path = gr.Dropdown(
                            checkpoints, label="IP Adapter Checkpoint",
                            info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)"
                        )
                    generate_btn = gr.Button("Generate")

                # Right column: generated image and downloadable file.
                with gr.Column():
                    output_image = gr.Image(label="Generated Image")
                    download_btn = gr.File(label="Download full-resolution")

            # Passed positionally to ``generate`` in this order.
            # NOTE(review): must match the parameter order of
            # XFluxPipeline.gradio_generate -- verify when either side changes.
            inputs = [prompt, image_prompt, controlnet_image, width, height, guidance,
                      num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
                      neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
                      lora_weight, local_path, lora_local_path, ip_local_path
            ]
            # The handler's return value is expected to unpack into these two outputs.
            generate_btn.click(
                fn=generate,
                inputs=inputs,
                outputs=[output_image, download_btn],
            )

        with gr.Tab("LoRA Finetuning"):
            # Choices: the working directory's parent (if any), the directory itself
            # and its sub-directories, as produced by list_train_data_dirs().
            data_dir = gr.Dropdown(list_train_data_dirs(),
                label="Training images (directory containing the training images)"
            )
            output_dir = gr.Textbox(label="Output Path", value="lora_checkpoint")

            with gr.Accordion("Training Options", open=True):
                # Textbox: the learning rate is passed to start_lora_training as text.
                lr = gr.Textbox(label="Learning Rate", value="1e-5")
                steps = gr.Slider(10000, 20000, 20000, step=100, label="Train Steps")
                rank = gr.Slider(1, 100, 16, step=1, label="LoRa Rank")

            training_btn = gr.Button("Start training")
            # No output components: the training result is not shown in the UI.
            training_btn.click(
                fn=start_lora_training,
                inputs=[data_dir, output_dir, lr, steps, rank],
                outputs=[],
            )

    return demo
if __name__ == "__main__":
    import torch.multiprocessing as mp

    # 'spawn' rather than 'fork': CUDA cannot be re-initialised in a forked
    # child process.
    mp.set_start_method('spawn', force=True)  # Corrected start method for CUDA
    config = Config()
    demo = create_demo(config.name, config.device, config.offload, config.ckpt_dir)
    # Honour Config.share instead of hard-coding share=True, which always
    # opened a public Gradio link on local runs and left the field unused.
    demo.launch(share=config.share)