# FLUX-Mini — Hugging Face Space (Gradio app, running on ZeroGPU)
from dataclasses import dataclass
from typing import Union, Optional, List, Any, Dict

import gradio as gr
import numpy as np
import random
import spaces
import torch
from safetensors.torch import load_file as load_sft
from huggingface_hub import hf_hub_download

from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from model import Flux, FluxParams
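
# Linearly maps the image token sequence length to the timestep-shift parameter
# `mu` used by the flow-matching scheduler: mu = m * image_seq_len + b, where the
# line passes through (base_seq_len, base_shift) and (max_seq_len, max_shift).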
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
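
# Resolves scheduler timesteps from exactly one of `timesteps`, `sigmas`, or
# `num_inference_steps` and returns (timesteps, num_inference_steps).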
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
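
# Generator-style variant of FluxPipeline.__call__: it runs the denoising loop
# with the custom Flux transformer (keyword interface: img / txt / txt_ids /
# img_ids / y / guidance) and yields the final image decoded with `good_vae`.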
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt=None,
    prompt_2=None,
    height=None,
    width=None,
    num_inference_steps: int = 28,
    timesteps=None,
    guidance_scale: float = 3.5,
    num_images_per_prompt=1,
    generator=None,
    latents=None,
    prompt_embeds=None,
    pooled_prompt_embeds=None,
    output_type="pil",
    return_dict=True,
    joint_attention_kwargs=None,
    max_sequence_length=512,
    good_vae=None,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    try:
        device = self._execution_device
    except Exception:
        device = torch.device('cuda:0')
    # 3. Encode prompt
    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)
    # Handle guidance. `dtype` here is the module-level torch.bfloat16 set during
    # model loading below; it exists before this function is ever called.
    guidance = torch.full([1], guidance_scale, device=device, dtype=dtype).expand(latents.shape[0])  # if self.transformer.params.guidance_embeds else None
    # 6. Denoising loop
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        timestep = t.expand(latents.shape[0]).to(dtype)

        noise_pred = self.transformer(
            img=latents.to(dtype).to(device),
            timesteps=(timestep / 1000).to(dtype),
            guidance=guidance.to(dtype).to(device),
            y=pooled_prompt_embeds.to(dtype).to(device),
            txt=prompt_embeds.to(dtype).to(device),
            txt_ids=text_ids.to(dtype).to(device),
            img_ids=latent_image_ids.to(dtype).to(device),
        )

        # Intermediate preview decode; the per-step yield below is disabled,
        # so only the final image (decoded with good_vae) is yielded.
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        # yield self.image_processor.postprocess(image, output_type=output_type)[0]

        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()

    # Final image using good_vae
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
    image = good_vae.decode(latents, return_dict=False)[0]
    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]
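
# Model specification for the distilled FLUX-Mini transformer: weights come from
# TencentARC/flux-mini, the autoencoder from black-forest-labs/FLUX.1-dev.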
@dataclass
class ModelSpec:
    params: FluxParams
    repo_id: str
    repo_flow: str
    repo_ae: str
    repo_id_ae: str
    ckpt_path: Optional[str]

config = ModelSpec(
    repo_id="TencentARC/flux-mini",
    repo_flow="flux-mini.safetensors",
    repo_id_ae="black-forest-labs/FLUX.1-dev",
    repo_ae="ae.safetensors",
    ckpt_path=None,
    params=FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=5,
        depth_single_blocks=10,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    )
)
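
# Downloads the transformer checkpoint from the Hub (unless a local ckpt_path is
# given) and loads it into the custom Flux model defined in model.py.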
def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
    if (config.ckpt_path is None
        and config.repo_id is not None
        and config.repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(config.repo_id, config.repo_flow.replace("sft", "safetensors"))
    else:
        ckpt_path = config.ckpt_path

    model = Flux(config.params)
    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=True)
    return model
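
# Module-level setup: load the scheduler, VAE, CLIP and T5 text encoders,
# tokenizers, and the FLUX-Mini transformer in bfloat16, then assemble a FluxPipeline.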
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
good_vae = vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder", torch_dtype=dtype).to(device)
tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=dtype).to(device)
tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2")
transformer = load_flow_model2(config, device).to(dtype).to(device)

pipe = FluxPipeline(
    scheduler,
    vae,
    text_encoder,
    tokenizer,
    text_encoder_2,
    tokenizer_2,
    transformer,
)
torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Bind the streaming call defined above as a method on this pipeline instance.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
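
# Text-to-image entry point: drains the streaming pipeline call and returns the
# last yielded image together with the seed used. The @spaces.GPU decorator is an
# assumption here (the `spaces` import and the ZeroGPU runtime suggest it was intended).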
@spaces.GPU  # assumed: allocates a ZeroGPU device for the duration of this call
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    torch.cuda.empty_cache()
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        pass
    return img, seed
examples = [
    "a lovely cat",
    "thousands of luminous oysters on a shore reflecting and refracting the sunset",
    "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta,",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# FLUX-Mini
A 3.2B param rectified flow transformer distilled from [FLUX.1 [dev]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]
""")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

demo.launch()