# Flux-Mini / app.py
# Last commit 4bc8a6f by daoyuan98: "Update app.py" (verified)
from dataclasses import dataclass
from typing import Union, Optional, List, Any, Dict
import gradio as gr
import numpy as np
import random
import spaces
import torch
from safetensors.torch import load_file as load_sft
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from model import Flux, FluxParams
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Map an image token count linearly onto the scheduler shift ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; lengths outside that interval are
    extrapolated along the same line rather than clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Any] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[Any] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return its timestep schedule.

    The schedule comes from exactly one source, checked in this order:
    explicit ``timesteps``, explicit ``sigmas``, otherwise
    ``num_inference_steps``. Extra keyword arguments (e.g. ``mu``) are
    forwarded unchanged to ``scheduler.set_timesteps``.

    Returns:
        ``(timesteps, num_inference_steps)``. ``timesteps`` is
        ``scheduler.timesteps`` after configuration. The step count is the
        length of that schedule when ``timesteps`` or ``sigmas`` were given,
        otherwise the ``num_inference_steps`` argument as passed in.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are provided.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive the schedule from a step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the step count is whatever the scheduler produced.
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt=None,
    prompt_2=None,
    height=None,
    width=None,
    num_inference_steps: int = 28,
    timesteps=None,
    guidance_scale: float = 3.5,
    num_images_per_prompt=1,
    generator=None,
    latents=None,
    prompt_embeds=None,
    pooled_prompt_embeds=None,
    output_type="pil",
    return_dict=True,
    joint_attention_kwargs=None,
    max_sequence_length=512,
    good_vae=None,
):
    """Sample an image with the Flux-Mini transformer and yield the result.

    Meant to be bound onto a ``FluxPipeline`` instance (``self``). It reuses the
    pipeline's input checks, prompt encoding, latent preparation, scheduler and
    image post-processing, but calls the custom transformer with keyword
    arguments ``img``, ``timesteps``, ``guidance``, ``y``, ``txt``, ``txt_ids``
    and ``img_ids``.

    It is a generator so callers can iterate it, but it yields exactly one item:
    the image decoded after the final denoising step.

    Args:
        prompt, prompt_2: Text prompt(s). A ``str`` counts as batch size 1.
        height, width: Output size in pixels. They default to
            ``default_sample_size * vae_scale_factor``.
        num_inference_steps: Number of denoising steps. It is replaced by the
            schedule length when ``timesteps`` is given.
        timesteps: Optional explicit schedule. When given, the default sigma
            schedule is not built.
        guidance_scale: Value fed to the transformer's ``guidance`` input.
        num_images_per_prompt, generator, latents, prompt_embeds,
        pooled_prompt_embeds, max_sequence_length: Forwarded to the pipeline's
            ``encode_prompt`` / ``prepare_latents`` helpers.
        output_type: Passed to ``image_processor.postprocess``.
        return_dict: Accepted for signature compatibility; unused.
        joint_attention_kwargs: Stored on the pipeline. Its ``"scale"`` entry
            is passed to ``encode_prompt`` as ``lora_scale``.
        good_vae: VAE used for the final decode. Falls back to ``self.vae``
            when None.

    Yields:
        The first element of the post-processed batch (a PIL image for
        ``output_type="pil"``).

    Note:
        Tensors are cast to the module-level ``dtype`` global, which must be
        defined before this function is called.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None:
        batch_size = len(prompt)
    else:
        # Prompt supplied only as precomputed embeddings.
        batch_size = prompt_embeds.shape[0]
    try:
        device = self._execution_device
    except Exception:
        # Fall back to the first GPU if the pipeline cannot report a device.
        device = torch.device("cuda:0")

    # 3. Encode prompt
    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps. retrieve_timesteps rejects timesteps and sigmas
    # together, so the default sigma ramp is built only when no explicit
    # schedule was supplied.
    if timesteps is None:
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    else:
        sigmas = None
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)

    # Guidance value broadcast over the batch dimension.
    guidance = torch.full([1], guidance_scale, device=device, dtype=dtype).expand(latents.shape[0])

    # 6. Denoising loop. No intermediate previews are decoded here; decoding
    # every step would cost a full VAE pass per step for an image nobody sees.
    for t in timesteps:
        if self.interrupt:
            continue
        timestep = t.expand(latents.shape[0]).to(dtype)
        noise_pred = self.transformer(
            img=latents.to(dtype).to(device),
            timesteps=(timestep / 1000).to(dtype),
            guidance=guidance.to(dtype).to(device),
            y=pooled_prompt_embeds.to(dtype).to(device),
            txt=prompt_embeds.to(dtype).to(device),
            txt_ids=text_ids.to(dtype).to(device),
            img_ids=latent_image_ids.to(dtype).to(device),
        )
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # Final decode with the requested VAE (or the pipeline's own VAE).
    final_vae = good_vae if good_vae is not None else self.vae
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / final_vae.config.scaling_factor) + final_vae.config.shift_factor
    image = final_vae.decode(latents, return_dict=False)[0]

    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]
@dataclass
class ModelSpec:
    """Location of the Flux-Mini weights plus the architecture used to build them.

    Fields keep their original positional order. ``ckpt_path`` may be None,
    which lets ``load_flow_model2`` fetch the checkpoint from the Hub instead.
    """

    # Architecture hyper-parameters passed to ``model.Flux``.
    params: FluxParams
    # Hub repository holding the transformer checkpoint.
    repo_id: str
    # Checkpoint filename inside ``repo_id``.
    repo_flow: str
    # Presumably the autoencoder filename inside ``repo_id_ae``; not read in this file — confirm.
    repo_ae: str
    # Presumably the Hub repository holding the autoencoder; not read in this file — confirm.
    repo_id_ae: str
    # Local checkpoint path. None means "download from ``repo_id``/``repo_flow``".
    ckpt_path: Optional[str] = None
# Flux-Mini spec: transformer weights come from TencentARC/flux-mini; the
# autoencoder is referenced from FLUX.1-dev. No local checkpoint, so
# load_flow_model2 downloads the weights.
config = ModelSpec(
    params=FluxParams(
        # in_channels is divided by 4 when preparing latents in the sampler.
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        # Block counts for model.Flux; presumably double- and single-stream
        # blocks respectively — confirm in model.py.
        depth=5,
        depth_single_blocks=10,
        axes_dim=[16, 56, 56],
        theta=10000,
        qkv_bias=True,
        guidance_embed=True,
    ),
    repo_id="TencentARC/flux-mini",
    repo_flow="flux-mini.safetensors",
    repo_ae="ae.safetensors",
    repo_id_ae="black-forest-labs/FLUX.1-dev",
    ckpt_path=None,
)
def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
    """Build ``Flux(config.params)`` and load its checkpoint, if one is available.

    If ``config.ckpt_path`` is None, ``hf_download`` is true, and both
    ``config.repo_id`` and ``config.repo_flow`` are set, the checkpoint is
    fetched with ``hf_hub_download``. Otherwise ``config.ckpt_path`` is used
    as-is. When no checkpoint path results, the freshly constructed model is
    returned without loading any weights.

    The safetensors file is loaded onto ``device`` and copied into the model
    with ``strict=True``. Mismatched keys therefore raise instead of being
    ignored.
    """
    ckpt_path = config.ckpt_path
    should_download = (
        hf_download
        and ckpt_path is None
        and config.repo_id is not None
        and config.repo_flow is not None
    )
    if should_download:
        # Rewrites any "sft" substring to "safetensors" (presumably for legacy
        # ".sft" filenames); a no-op for "flux-mini.safetensors".
        filename = config.repo_flow.replace("sft", "safetensors")
        ckpt_path = hf_hub_download(config.repo_id, filename)

    model = Flux(config.params)
    if ckpt_path is None:
        return model

    state_dict = load_sft(ckpt_path, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    return model
# All loaded weights are cast to bfloat16. CPU is used only when CUDA is unavailable.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Every component except the transformer comes from the FLUX.1-dev repository.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype
).to(device)
# The same AutoencoderKL instance is also handed to the sampler for the final decode.
good_vae = vae
text_encoder = CLIPTextModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="text_encoder", torch_dtype=dtype
).to(device)
tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
text_encoder_2 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=dtype
).to(device)
tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")

# The Flux-Mini transformer replaces the FLUX.1-dev one.
transformer = load_flow_model2(config, device).to(dtype).to(device)

pipe = FluxPipeline(
    scheduler=scheduler,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    transformer=transformer,
)
torch.cuda.empty_cache()

# Upper bound for random seeds and the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders.
MAX_IMAGE_SIZE = 1024

# Attach the custom sampler to `pipe` as a bound method.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
@spaces.GPU(duration=30)
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt`` and return it together with the seed used.

    Gradio callback, wrapped by ``spaces.GPU`` with a 30 s budget. When
    ``randomize_seed`` is true, a fresh seed in ``[0, MAX_SEED]`` replaces
    ``seed``. ``progress`` is never read in the body; it is kept in the
    signature for Gradio's progress tracking.

    Returns:
        ``(image, seed)``: the last image yielded by the sampler and the integer
        seed that produced it.

    Raises:
        RuntimeError: if the sampler yields no image.
    """
    torch.cuda.empty_cache()
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats. torch.Generator.manual_seed and
    # np.linspace (used for the step count) require integers.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    final_image = None
    for final_image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        # Drain the generator; only the last yielded image is kept.
        pass
    if final_image is None:
        raise RuntimeError("Image generation produced no output.")
    return final_image, seed
# Example prompts shown under the UI (used by gr.Examples with `prompt` as the only input).
examples = [
    "a lovely cat",
    "thousands of luminous oysters on a shore reflecting and refracting the sunset",
    "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta,",
]

# Centers the main column and caps its width.
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
# Gradio UI: a prompt box with a Run button, the result image, an
# "Advanced Settings" accordion (seed, size, guidance, steps) and example
# prompts. Component creation order inside each context manager defines
# the layout, so statements here must stay in this order.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX-Mini
A 3.2B param rectified flow transformer distilled from [FLUX.1 [dev]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]
""")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            # Also an output of infer, so the slider shows the seed actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        # Examples pass only the prompt; infer's defaults fill the other
        # parameters. Outputs are cached lazily (cache_examples="lazy").
        gr.Examples(
            examples = examples,
            fn = infer,
            inputs = [prompt],
            outputs = [result, seed],
            cache_examples="lazy"
        )
    # Both the Run button and pressing Enter in the prompt box trigger infer.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
        inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs = [result, seed]
    )
demo.launch()