"""Gradio demo for Würstchen text-to-image generation ("Waves Weaves").

Loads the Würstchen prior and decoder pipelines when a CUDA device is available.
It builds a Gradio UI that streams optional prior previews followed by the final
decoded images, and records results via the project's gallery-history helpers.
"""
import os
import random
from typing import Iterator, List

import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers.utils import numpy_to_pil
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

from previewer.modules import Previewer
from gallery_history import fetch_gallery_history, show_gallery_history

os.environ["TOKENIZERS_PARALLELISM"] = "false"

DESCRIPTION = "# Waves Weaves"
# NOTE(review): this literal originally spanned raw newlines (its markup appears to have been
# stripped); only the line break survives. Restore the subtitle text here if it is known.
DESCRIPTION += "\n"
if not torch.cuda.is_available():
    # Blank line first so Markdown renders the notice as its own paragraph.
    DESCRIPTION += "\n\nRunning on CPU 🥶"

MAX_SEED = np.iinfo(np.int32).max
# Pre-computing example outputs only makes sense when the pipelines can actually run.
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
PREVIEW_IMAGES = True

dtype = torch.float16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
        "warp-ai/wuerstchen-prior-model-interpolated", torch_dtype=dtype
    )
    decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=dtype)
    if ENABLE_CPU_OFFLOAD:
        prior_pipeline.enable_model_cpu_offload()
        decoder_pipeline.enable_model_cpu_offload()
    else:
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
        decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)

    if PREVIEW_IMAGES:
        previewer = Previewer()
        previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
        previewer.eval().requires_grad_(False).to(device).to(dtype)

        def callback_prior(i, t, latents):
            """Render prior latents into preview images.

            ``i`` and ``t`` (step index / timestep) are accepted for the callback signature but unused.
            Returns a list of PIL images produced from the previewer output, which is permuted from
            channels-first to channels-last before conversion.
            """
            output = previewer(latents)
            return numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
    else:
        previewer = None
        callback_prior = None
else:
    # No GPU: nothing is loaded. Define every name `generate` reads so it can fail cleanly.
    prior_pipeline = None
    decoder_pipeline = None
    previewer = None
    callback_prior = None


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in ``[0, MAX_SEED]`` if requested, else ``seed`` unchanged."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 60,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 12,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
) -> Iterator[List[PIL.Image.Image]]:
    """Generate images for ``prompt``, streaming previews and then the final images.

    Args:
        prompt: Text prompt.
        negative_prompt: Text describing content to avoid.
        seed: Seed for the shared ``torch.Generator`` used by both stages.
        width, height: Output size in pixels.
        prior_num_inference_steps: Currently ignored. The prior always runs on the fixed
            ``DEFAULT_STAGE_C_TIMESTEPS`` schedule; the parameter is kept for UI/API compatibility.
        prior_guidance_scale: Classifier-free guidance scale for the prior.
        decoder_num_inference_steps: Number of decoder denoising steps.
        decoder_guidance_scale: Guidance scale for the decoder.
        num_images_per_prompt: Number of images to produce.

    Yields:
        Lists of PIL images. When ``PREVIEW_IMAGES`` is set, first the prior previews, then finally
        the decoded images.

    Raises:
        gr.Error: If the pipelines were not loaded (no CUDA device).
    """
    if prior_pipeline is None or decoder_pipeline is None:
        raise gr.Error("The Wuerstchen pipelines are not loaded: this demo requires a CUDA GPU.")

    # Gradio sliders may deliver floats; the generator and pipelines expect integers.
    generator = torch.Generator().manual_seed(int(seed))
    prior_output = prior_pipeline(
        prompt=prompt,
        height=int(height),
        width=int(width),
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=int(num_images_per_prompt),
        generator=generator,
        callback=callback_prior,
    )

    if PREVIEW_IMAGES:
        # NOTE(review): this assumes the installed prior pipeline returns an iterator when a
        # callback is given. It should yield preview lists per step and the pipeline output last.
        # Stock diffusers returns the output object directly — confirm the diffusers build in use.
        for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
            step_output = next(prior_output)
            if isinstance(step_output, list):
                yield step_output
        prior_output = step_output

    decoder_output = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=int(decoder_num_inference_steps),
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        generator=generator,
        output_type="pil",
    ).images
    yield decoder_output


examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
]

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Imagine... 'A puppy', 'A Delicious Fruit Cake', 'Copacabana Beach'...",
                container=False,
            )
            run_button = gr.Button("Weave", scale=0)
        result = gr.Gallery(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        negative_prompt = gr.Text(
            label="What I do NOT want",
            max_lines=1,
            placeholder="Uncheck seed to iterate and finetune.",
            value="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            num_images_per_prompt = gr.Slider(
                label="Number of Images",
                minimum=1,
                maximum=2,
                step=1,
                value=2,
            )
        with gr.Row():
            prior_guidance_scale = gr.Slider(
                label="Prior Guidance Scale",
                minimum=0,
                maximum=20,
                step=0.1,
                value=17.0,
            )
            prior_num_inference_steps = gr.Slider(
                label="Prior Inference Steps",
                minimum=30,
                maximum=60,
                step=1,
                value=30,
            )
            decoder_guidance_scale = gr.Slider(
                label="Decoder Guidance Scale",
                minimum=0,
                maximum=0,
                step=0.1,
                value=0.0,
            )
            decoder_num_inference_steps = gr.Slider(
                label="Decoder Inference Steps",
                minimum=4,
                maximum=12,
                step=1,
                value=12,
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    history = show_gallery_history()

    # Order must match the positional parameters of `generate`.
    inputs = [
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        prior_num_inference_steps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]

    # Pressing Enter in either text box or clicking "Weave" runs the same chain:
    # (re)seed -> generate -> record in history. Only the prompt-submit chain is exposed
    # through the API, under the name "run".
    for trigger, generate_api_name in (
        (prompt.submit, "run"),
        (negative_prompt.submit, False),
        (run_button.click, False),
    ):
        trigger(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=generate,
            inputs=inputs,
            outputs=result,
            api_name=generate_api_name,
        ).then(
            fn=fetch_gallery_history,
            inputs=[prompt, result],
            outputs=history,
            queue=False,
        )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()