import os
import time
from io import BytesIO
import uuid
import torch
import gradio as gr
import spaces
import numpy as np
from einops import rearrange
from PIL import Image, ExifTags
from dataclasses import dataclass
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack, prepare_tokens
from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5
import jax
import jax.numpy as jnp
from flax import nnx
from jax import Array as Tensor
from einops import repeat


@dataclass
class SamplingOptions:
    """Parameters for a single text-to-image request."""

    prompt: str
    width: int  # output width in pixels
    height: int  # output height in pixels
    num_steps: int  # number of denoising steps
    guidance: float  # guidance value forwarded to ``denoise``
    seed: int | None  # None means "draw a random seed"


# Score threshold intended for the NSFW classifier (currently not loaded).
NSFW_THRESHOLD = 0.85


@spaces.GPU
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    """Load the FLUX components and split each nnx module into (graphdef, state).

    Args:
        name: model config name (e.g. "flux-schnell").
        device: device handed to the loaders.
        offload: when True, the flow model and autoencoder are loaded on "cpu".
        is_schnell: selects a 256-token T5 (schnell) instead of 512.

    Returns:
        A 7-tuple ``(model, ae, t5, t5_tokenizer, clip, clip_tokenizer,
        nsfw_classifier)``. ``model``, ``ae``, ``t5`` and ``clip`` are the
        results of ``nnx.split`` (re-merged later with ``nnx.merge``), and
        ``nsfw_classifier`` is always None because the classifier is disabled.
    """
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    # nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    # return model, ae, t5, clip, nsfw_classifier
    return (
        nnx.split(model),
        nnx.split(ae),
        nnx.split(t5),
        t5.tokenizer,
        nnx.split(clip),
        clip.tokenizer,
        None,  # NSFW classifier placeholder (disabled)
    )


@jax.jit
def encode(ae, x):
    """Encode an image batch with the autoencoder.

    ``ae`` is the ``(graphdef, state)`` pair from ``nnx.split``; it is merged
    back into a module inside the jitted function.
    """
    ae = nnx.merge(*ae)
    return ae.encode(x)


def _generate(
    model,
    ae,
    t5,
    clip,
    x,
    t5_tokens,
    clip_tokens,
    num_steps,
    guidance,
    # init_image=None,
    # image2image_strength=0.0,
    shift=True,
):
    """Denoise latent noise ``x`` conditioned on the prompt tokens, then decode.

    Args:
        model, ae, t5, clip: ``(graphdef, state)`` pairs from ``nnx.split``.
        x: latent noise laid out as (batch, h, w, channels), where ``h`` and
            ``w`` are latent dims (pixel size / 8, per the ``unpack`` call).
        t5_tokens, clip_tokens: token ids from the respective tokenizers.
        num_steps: number of denoising steps.
        guidance: guidance value forwarded to ``denoise``.
        shift: whether ``get_schedule`` shifts the timesteps.

    Returns:
        The autoencoder's decoded output for the denoised latents.
    """
    _, h, w, _ = x.shape
    model = nnx.merge(*model)
    ae = nnx.merge(*ae)
    t5 = nnx.merge(*t5)
    clip = nnx.merge(*clip)

    # The schedule shift depends on the number of image tokens. With the
    # latent laid out as (b, h, w, c), that count is h * w // 4 (one token per
    # 2x2 latent patch). The previous x.shape[-1] * x.shape[-2] evaluated to
    # c * w in this layout.
    timesteps = get_schedule(
        num_steps,
        h * w // 4,
        shift=shift,
    )
    # if init_image is not None:
    #     t_idx = int((1 - image2image_strength) * num_steps)
    #     t = timesteps[t_idx]
    #     timesteps = timesteps[t_idx:]
    #     x = t * x + (1.0 - t) * init_image.astype(x.dtype)
    inp = prepare(t5, clip, x, t5_tokens, clip_tokens)
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
    # unpack takes the pixel-space height/width: latent dims * 8.
    x = unpack(x.astype(jnp.float32), h * 8, w * 8)
    x = ae.decode(x)
    return x
# jit-compiled sampler. num_steps and shift are static arguments (they feed
# get_schedule), so each distinct value triggers a separate compilation.
generate = jax.jit(_generate, static_argnames=("num_steps", "shift"))


def prepare_tokens(
    t5_tokenizer,
    clip_tokenizer,
    prompt: str | list[str],
    t5_max_length: int = 512,
) -> tuple[Tensor, Tensor]:
    """Tokenize prompt(s) for T5 and CLIP, truncated and padded to fixed lengths.

    Args:
        t5_tokenizer: tokenizer for the T5 text encoder.
        clip_tokenizer: tokenizer for the CLIP text encoder.
        prompt: a single prompt or a list of prompts.
        t5_max_length: T5 sequence length (defaults to 512; 256 matches the
            schnell T5 configuration in ``get_models``).

    Returns:
        ``(t5_tokens, clip_tokens)`` input-id arrays (CLIP padded to 77).
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    t5_tokens = t5_tokenizer(
        prompt,
        truncation=True,
        max_length=t5_max_length,
        return_length=False,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="jax",
    )["input_ids"]
    clip_tokens = clip_tokenizer(
        prompt,
        truncation=True,
        max_length=77,
        return_length=False,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="jax",
    )["input_ids"]
    return t5_tokens, clip_tokens


class FluxGenerator:
    """Holds the loaded FLUX components and serves text-to-image requests."""

    def __init__(self, model_name: str, device: str, offload: bool):
        # ``device`` is accepted for interface compatibility but ignored:
        # the models are loaded with device=None.
        self.device = None
        self.offload = offload
        self.model_name = model_name
        self.is_schnell = model_name == "flux-schnell"
        (
            self.model,
            self.ae,
            self.t5,
            self.t5_tokenizer,
            self.clip,
            self.clip_tokenizer,
            self.nsfw_classifier,
        ) = get_models(
            model_name,
            device=self.device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )
        # Seed the PRNG from OS entropy. A fixed jax.random.key(0) would make
        # the "random seed" sequence identical after every restart.
        self.key = jax.random.key(int.from_bytes(os.urandom(4), "little") & 0x7FFFFFFF)

    @spaces.GPU(duration=180)
    def generate_image(
        self,
        img_size,
        num_steps,
        guidance,
        seed,
        prompt,
        # init_image=None,
        # image2image_strength=0.0,
        add_sampling_metadata=True,
    ):
        """Generate one image for ``prompt`` and save it as a JPEG with EXIF data.

        Args:
            img_size: "1,024x1,024" selects 1024x1024; anything else gives 512x512.
            num_steps: number of denoising steps.
            guidance: guidance value forwarded to the sampler.
            seed: integer-convertible seed; -1 requests a random seed.
            prompt: text prompt.
            add_sampling_metadata: also store the prompt in the EXIF
                ImageDescription tag.

        Returns:
            ``(image, runtime_seconds, seed_str, filename, warning)``. The
            filtered branch returns the same 5-tuple shape with ``image`` and
            ``filename`` set to None.
        """
        seed = int(seed)
        if seed == -1:
            seed = None

        if img_size == "1,024x1,024":
            width, height = 1024, 1024
        else:
            width, height = 512, 512

        opts = SamplingOptions(
            prompt=prompt,
            width=width,
            height=height,
            num_steps=num_steps,
            guidance=guidance,
            seed=seed,
        )

        if opts.seed is None:
            # opts.seed = torch.Generator(device="cpu").seed()
            key, self.key = jax.random.split(self.key, 2)
            # Convert to a Python int to match SamplingOptions.seed's type.
            opts.seed = int(jax.random.randint(key, (), 0, 2**30))

        print(f"Generating '{opts.prompt}' with seed {opts.seed}")
        t0 = time.perf_counter()

        # if init_image is not None:
        #     if isinstance(init_image, np.ndarray):
        #         init_image = jnp.asarray(init_image).astype(jnp.float32) / 255.0
        #         init_image = init_image[None]
        #     init_image = jax.image.resize(init_image, (opts.height, opts.width), method="lanczos5")
        #     init_image = encode(self.ae, init_image)

        # prepare input (schnell's T5 is configured for 256 tokens in get_models)
        t5_tokens, clip_tokens = prepare_tokens(
            self.t5_tokenizer,
            self.clip_tokenizer,
            prompt=opts.prompt,
            t5_max_length=256 if self.is_schnell else 512,
        )
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=None,
            dtype=jnp.bfloat16,
            seed=opts.seed,
        )
        x = generate(
            self.model,
            self.ae,
            self.t5,
            self.clip,
            x,
            t5_tokens,
            clip_tokens,
            opts.num_steps,
            opts.guidance,
            shift=(not self.is_schnell),
        )
        # JAX dispatches asynchronously. Wait for the result so ``runtime``
        # covers the actual computation rather than only the dispatch.
        x = jax.block_until_ready(x)
        t1 = time.perf_counter()
        runtime = t1 - t0

        # bring into PIL format: map [-1, 1] to [0, 255]. Compute in float32 in
        # case the decoder returns a low-precision dtype such as bfloat16.
        x = jnp.clip(x.astype(jnp.float32), -1, 1)
        # x = embed_watermark(x.astype(jnp.float32))
        img = Image.fromarray(np.asarray(127.5 * (x[0] + 1.0)).astype(np.uint8))

        # NSFW filtering is disabled (no classifier is loaded), so every image is kept.
        # nsfw_score = [x["score"] for x in self.nsfw_classifier(img) if x["label"] == "nsfw"][0]
        is_safe = True
        if is_safe:
            filename = f"output/gradio/{uuid.uuid4()}.jpg"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            exif_data = Image.Exif()
            # img2img outputs would use "AI generated;img2img;flux" instead.
            exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = self.model_name
            if add_sampling_metadata:
                exif_data[ExifTags.Base.ImageDescription] = prompt
            img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

            return img, runtime, str(opts.seed), filename, None
        # Same 5-tuple shape as the success path so callers can unpack either.
        return None, runtime, str(opts.seed), None, "Your generated image may contain NSFW content."
@spaces.GPU(duration=300) def create_demo(model_name: str, device: str = "cuda", offload: bool = False): generator = FluxGenerator(model_name, device, offload) is_schnell = model_name == "flux-schnell" with open("./assets/banner.html") as f: banner = f.read() with gr.Blocks() as demo: with gr.Column(elem_id="app-container"): gr.HTML(f"""