# SRPO / app.py
# Author: akhaliq (HF Staff)
# Commit 44895ee: "Update app.py"
import gradio as gr
import torch
import spaces
from diffusers import FluxPipeline
from safetensors.torch import load_file
# Load the base FLUX.1-dev pipeline in bfloat16 and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev',
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
).to('cuda')

# Load SRPO weights: download Tencent's SRPO transformer checkpoint and
# overwrite the pipeline transformer's parameters with it.
from huggingface_hub import hf_hub_download

srpo_path = hf_hub_download(
    repo_id="tencent/SRPO",
    filename="diffusion_pytorch_model.safetensors",
)
state_dict = load_file(srpo_path)
pipe.transformer.load_state_dict(state_dict)
# load_state_dict copies the checkpoint values into the transformer's existing
# parameters, so the host-side copy is no longer needed. Drop the reference so
# the process does not keep a second full set of transformer weights in RAM.
del state_dict
@spaces.GPU(duration=120)
def generate_image(
    prompt,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    seed=-1,
):
    """Generate one image from ``prompt`` with the SRPO-weighted FLUX pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Value forwarded to the pipeline as ``guidance_scale``.
        num_inference_steps: Number of denoising steps.
        seed: RNG seed. ``-1`` requests a random seed. Any other negative
            value, or ``None`` (e.g. a cleared UI number field), is treated
            the same way.

    Returns:
        A ``(image, seed)`` tuple: the first image of the pipeline output and
        the integer seed actually used, so the result can be reproduced.
    """
    # The UI Number field can be cleared (None) or deliver a float, so
    # normalise it before drawing a random seed or seeding the generator.
    if seed is None or seed < 0:
        seed = torch.randint(0, 2**32, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device='cuda').manual_seed(seed)

    # Cast UI-provided numbers explicitly: the step count in particular must
    # be an int for the scheduler's timestep computation.
    image = pipe(
        prompt=prompt,
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    return image, seed
# Gradio UI: prompt box and generate button, optional advanced sampling
# controls, example prompts, and an output image plus the seed that was used.
with gr.Blocks(title="FLUX SRPO Text-to-Image", theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray", neutral_hue="slate")) as demo:
    gr.Markdown("# Flux SRPO")
    gr.Markdown("Generate images using FLUX model enhanced with Tencent's SRPO technique")
    gr.Markdown("Built with [AnyCoder](https://huggingface.co/spaces/akhaliq/anycoder)")

    output_image = gr.Image(label="Generated Image", type="pil")
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Describe the image you want to generate...",
        lines=3,
    )
    generate_btn = gr.Button("Generate Image", variant="primary", size="lg")

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            # step=64 keeps dimensions on a multiple of 64 pixels.
            width = gr.Slider(
                minimum=256,
                maximum=2048,
                value=1024,
                step=64,
                label="Width",
            )
            height = gr.Slider(
                minimum=256,
                maximum=2048,
                value=1024,
                step=64,
                label="Height",
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=20.0,
                value=3.5,
                step=0.5,
                label="Guidance Scale",
            )
            num_inference_steps = gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=5,
                label="Inference Steps",
            )
        seed = gr.Number(
            label="Seed (-1 for random)",
            value=-1,
            precision=0,
        )

    # Output-only: shows the seed generate_image actually used. Editing it
    # would have no effect, so it is not user-editable.
    used_seed = gr.Number(label="Seed Used", precision=0, interactive=False)

    gr.Examples(
        examples=[
            ["The Death of Ophelia by John Everett Millais, Pre-Raphaelite painting, Ophelia floating in a river surrounded by flowers, detailed natural elements, melancholic and tragic atmosphere"],
            ["A serene Japanese garden with cherry blossoms, koi pond, traditional wooden bridge, soft morning light, photorealistic"],
            ["Cyberpunk cityscape at night, neon lights, flying cars, rain-slicked streets, blade runner aesthetic, highly detailed"],
            ["Portrait of a majestic lion in golden hour light, detailed fur texture, intense gaze, African savanna background"],
            ["Abstract colorful explosion of paint in water, high speed photography, vibrant colors mixing, dramatic lighting"],
        ],
        inputs=prompt,
        label="Example Prompts",
    )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed],
        outputs=[output_image, used_seed],
    )

# Only start the server when run as a script (as Spaces does with
# `python app.py`), not when the module is merely imported.
if __name__ == "__main__":
    demo.launch()