from dataclasses import dataclass
from typing import Union, Optional, List, Any, Dict

import gradio as gr
import numpy as np
import random
import spaces
import torch
from safetensors.torch import load_file as load_sft
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL, FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from model import Flux, FluxParams


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    # Linearly interpolate the timestep shift `mu` between base_shift and max_shift
    # as a function of the number of image tokens.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt=None,
    prompt_2=None,
    height=None,
    width=None,
    num_inference_steps: int = 28,
    timesteps=None,
    guidance_scale: float = 3.5,
    num_images_per_prompt=1,
    generator=None,
    latents=None,
    prompt_embeds=None,
    pooled_prompt_embeds=None,
    output_type="pil",
    return_dict=True,
    joint_attention_kwargs=None,
    max_sequence_length=512,
    good_vae=None,
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    try:
        device = self._execution_device
    except Exception:
        device = torch.device("cuda:0")

    # 3. Encode prompt
    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
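    # The sigma schedule below runs linearly from 1.0 down to 1/num_inference_steps;
    # `mu` (computed by calculate_shift above from the number of image tokens) is passed
    # to the FlowMatch scheduler so that larger images get a shifted schedule and follow
    # a comparable denoising trajectory to smaller ones.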
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)

    # Handle guidance (`dtype` is the module-level torch.bfloat16 defined below)
    guidance = torch.full([1], guidance_scale, device=device, dtype=dtype).expand(latents.shape[0])  # if self.transformer.params.guidance_embeds else None
    # print(latent_image_ids.shape, text_ids.shape, pooled_prompt_embeds.shape)

    # 6. Denoising loop
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        timestep = t.expand(latents.shape[0]).to(dtype)

        noise_pred = self.transformer(
            img=latents.to(dtype).to(device),
            timesteps=(timestep / 1000).to(dtype),
            guidance=guidance.to(dtype).to(device),
            y=pooled_prompt_embeds.to(dtype).to(device),
            txt=prompt_embeds.to(dtype).to(device),
            txt_ids=text_ids.to(dtype).to(device),
            img_ids=latent_image_ids.to(dtype).to(device),
        )

        # Intermediate preview decode (yielding per-step previews is currently disabled)
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        # yield self.image_processor.postprocess(image, output_type=output_type)[0]

        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()

    # Final image using good_vae
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
    image = good_vae.decode(latents, return_dict=False)[0]
    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]


@dataclass
class ModelSpec:
    params: FluxParams
    repo_id: str
    repo_flow: str
    repo_ae: str
    repo_id_ae: str
    ckpt_path: str


config = ModelSpec(
    repo_id="TencentARC/flux-mini",
    repo_flow="flux-mini.safetensors",
    repo_id_ae="black-forest-labs/FLUX.1-dev",
    repo_ae="ae.safetensors",
    ckpt_path=None,
    params=FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=5,
        depth_single_blocks=10,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    )
)


def load_flow_model2(config, device: str = "cuda", hf_download: bool = True):
    # Download the flow-transformer checkpoint from the Hub unless a local path is given.
    if (config.ckpt_path is None
        and config.repo_id is not None
        and config.repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(config.repo_id, config.repo_flow.replace("sft", "safetensors"))
    else:
        ckpt_path = config.ckpt_path

    model = Flux(config.params)
    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=True)
    return model


dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
good_vae = vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
text_encoder = CLIPTextModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder", torch_dtype=dtype).to(device)
tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer")
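# CLIP supplies the pooled text embedding (fed to the transformer as `y`), while the T5
# encoder loaded next provides the per-token prompt embeddings (`txt`); both text encoders
# and the autoencoder reuse the FLUX.1-dev weights.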
subfolder="tokenizer") text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=dtype).to(device) tokenizer_2 = T5TokenizerFast.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer_2") transformer = load_flow_model2(config, device).to(dtype).to(device) pipe = FluxPipeline( scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer ) torch.cuda.empty_cache() MAX_SEED = np.iinfo(np.int32).max MAX_IMAGE_SIZE = 1024 pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe) @spaces.GPU(duration=30) def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)): torch.cuda.empty_cache() if randomize_seed: seed = random.randint(0, MAX_SEED) generator = torch.Generator().manual_seed(seed) for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images( prompt=prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, width=width, height=height, generator=generator, output_type="pil", good_vae=good_vae, ): pass return img, seed examples = [ "a lovely cat", "thousands of luminous oysters on a shore reflecting and refracting the sunset", "profile of sad Socrates, full body, high detail, dramatic scene, Epic dynamic action, wide angle, cinematic, hyper realistic, concept art, warm muted tones as painted by Bernie Wrightson, Frank Frazetta," ] css=""" #col-container { margin: 0 auto; max-width: 520px; } """ with gr.Blocks(css=css) as demo: with gr.Column(elem_id="col-container"): gr.Markdown(f"""# FLUX-Mini A 3.2B param rectified flow transformer distilled from [FLUX.1 [dev]](https://blackforestlabs.ai/) [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] """) with gr.Row(): prompt = gr.Text( label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, ) run_button = gr.Button("Run", scale=0) result = gr.Image(label="Result", show_label=False) with gr.Accordion("Advanced Settings", open=False): seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, ) randomize_seed = gr.Checkbox(label="Randomize seed", value=True) with gr.Row(): width = gr.Slider( label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, ) height = gr.Slider( label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024, ) with gr.Row(): guidance_scale = gr.Slider( label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5, ) num_inference_steps = gr.Slider( label="Number of inference steps", minimum=1, maximum=50, step=1, value=28, ) gr.Examples( examples = examples, fn = infer, inputs = [prompt], outputs = [result, seed], cache_examples="lazy" ) gr.on( triggers=[run_button.click, prompt.submit], fn = infer, inputs = [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps], outputs = [result, seed] ) demo.launch()