# Page residue captured together with the source (not part of the program):
# "Spaces: Runtime error / Runtime error"
import os
import random
from typing import Iterator, List

import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from diffusers.utils import numpy_to_pil

from gallery_history import fetch_gallery_history, show_gallery_history
from previewer.modules import Previewer
# Set before any pipeline (and therefore any tokenizer) is constructed below.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Markdown rendered at the top of the page; a CPU notice is appended when no
# CUDA device is visible.
DESCRIPTION = "# Waves Weaves"
DESCRIPTION += "\n<p style=\"text-align: center\"></p>"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶</p>"

# Upper bound of the seed slider and of randomly drawn seeds (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Example outputs are pre-computed only on GPU hosts that opt in via the env var.
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
# Upper bound of the width/height sliders; overridable through the environment.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
# When True, intermediate prior results are streamed to the gallery.
PREVIEW_IMAGES = True

# Precision and placement shared by both pipelines and the previewer.
dtype = torch.float16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load both Wuerstchen stages only when a CUDA device is present. On CPU-only
# hosts every model handle is set to None so the UI can still be built.
# NOTE(review): block nesting was reconstructed from a flattened source; the
# PREVIEW_IMAGES section is assumed to live inside the CUDA branch — confirm.
if torch.cuda.is_available():
    prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
        "warp-ai/wuerstchen-prior-model-interpolated", torch_dtype=dtype
    )
    decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=dtype)

    if ENABLE_CPU_OFFLOAD:
        prior_pipeline.enable_model_cpu_offload()
        decoder_pipeline.enable_model_cpu_offload()
    else:
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
        decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)

    if PREVIEW_IMAGES:
        previewer = Previewer()
        # NOTE(review): torch.load unpickles the file; this is only safe because
        # the checkpoint is a local, trusted asset shipped next to the app.
        previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
        previewer.eval().requires_grad_(False).to(device).to(dtype)

        def callback_prior(i, t, latents):
            """Turn the prior's current latents into preview PIL images.

            ``i`` and ``t`` are accepted to match the pipeline's callback
            signature but are not used. The previewer output is clamped to
            [0, 1], axis 1 is moved to the end (presumably channels-first to
            channels-last — TODO confirm), and the result is converted with
            ``numpy_to_pil``.
            """
            output = previewer(latents)
            output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
            return output
    else:
        previewer = None
        callback_prior = None
else:
    prior_pipeline = None
    decoder_pipeline = None
    # Define the preview handles here too: `generate` references
    # `callback_prior` unconditionally and would otherwise hit a NameError on
    # CPU-only hosts.
    previewer = None
    callback_prior = None
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for the next generation.

    Args:
        seed: The seed currently selected in the UI.
        randomize_seed: When true, ``seed`` is ignored and a fresh value is
            drawn uniformly from ``[0, MAX_SEED]`` (both ends inclusive).

    Returns:
        ``seed`` unchanged, or the newly drawn random seed.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 60,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 12,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
) -> Iterator[List[PIL.Image.Image]]:
    """Run the prior and decoder stages, yielding image lists for the gallery.

    This is a generator: with ``PREVIEW_IMAGES`` enabled it yields every
    list-valued intermediate result produced by the prior stage, and it always
    finishes by yielding the decoder's final images.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what to avoid.
        seed: Seed for the shared ``torch.Generator``; coerced to ``int``.
        width: Requested image width, forwarded to the prior.
        height: Requested image height, forwarded to the prior.
        prior_num_inference_steps: Accepted for interface compatibility but
            NOT used — the prior always runs on ``DEFAULT_STAGE_C_TIMESTEPS``.
        prior_guidance_scale: Guidance scale for the prior stage.
        decoder_num_inference_steps: Number of decoder steps.
        decoder_guidance_scale: Guidance scale for the decoder stage.
        num_images_per_prompt: Number of images requested from the prior.

    Yields:
        Lists of PIL images: previews first (if enabled), final images last.

    Raises:
        gr.Error: If the pipelines were not loaded (no CUDA device at startup).
    """
    if prior_pipeline is None or decoder_pipeline is None:
        # The pipelines are only loaded on CUDA hosts; fail with a readable
        # message instead of an opaque NameError/TypeError.
        raise gr.Error("The models are not loaded: this app needs a CUDA GPU to generate images.")

    # Slider values may arrive as floats; manual_seed requires an int.
    generator = torch.Generator().manual_seed(int(seed))

    prior_output = prior_pipeline(
        prompt=prompt,
        height=height,
        width=width,
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        callback=callback_prior,
    )

    if PREVIEW_IMAGES:
        # In preview mode the prior's return value is consumed with next(), one
        # item per timestep, so it must be an iterator.
        # NOTE(review): stock diffusers pipelines return an output object, not
        # an iterator — presumably a patched build is installed; confirm.
        for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
            step_result = next(prior_output)
            if isinstance(step_result, list):
                yield step_result
        # The last item pulled is used as the prior's final output.
        prior_output = step_result

    decoder_output = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=decoder_num_inference_steps,
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        generator=generator,
        output_type="pil",
    ).images
    yield decoder_output
# Prompts offered as one-click examples below the gallery.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
]
# Build the page: prompt row + gallery, an "Advanced options" accordion,
# examples, a history gallery, and the event wiring.
# NOTE(review): the layout nesting below was reconstructed from a flattened
# source — confirm which components belong to which Group/Row/Accordion.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Imagine... 'A puppy', 'A Delicious Fruit Cake', 'Copacabana Beach'...",
                container=False,
            )
            run_button = gr.Button("Weave", scale=0)
        result = gr.Gallery(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        negative_prompt = gr.Text(
            label="What I do NOT want",
            max_lines=1,
            placeholder="Uncheck seed to iterate and finetune.",
            value="lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            num_images_per_prompt = gr.Slider(
                label="Number of Images",
                minimum=1,
                maximum=2,
                step=1,
                value=2,
            )
        with gr.Row():
            prior_guidance_scale = gr.Slider(
                label="Prior Guidance Scale",
                minimum=0,
                maximum=20,
                step=0.1,
                value=17.0,
            )
            prior_num_inference_steps = gr.Slider(
                label="Prior Inference Steps",
                minimum=30,
                maximum=60,
                step=1,
                value=30,
            )
            # minimum == maximum: the decoder guidance scale is pinned to 0.
            decoder_guidance_scale = gr.Slider(
                label="Decoder Guidance Scale",
                minimum=0,
                maximum=0,
                step=0.1,
                value=0.0,
            )
            decoder_num_inference_steps = gr.Slider(
                label="Decoder Inference Steps",
                minimum=4,
                maximum=12,
                step=1,
                value=12,
            )

    # Only the prompt is supplied for examples; every other `generate`
    # argument falls back to the function's own defaults.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    history = show_gallery_history()

    # Order must match the positional parameters of `generate`.
    inputs = [
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        prior_num_inference_steps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]

    # Pressing Enter in either text box or clicking the button runs the same
    # chain: pick the seed -> generate -> record in the history gallery.
    # Only the prompt-submit chain exposes `generate` as the "run" API endpoint.
    triggers = (
        (prompt.submit, "run"),
        (negative_prompt.submit, False),
        (run_button.click, False),
    )
    for trigger, generate_api_name in triggers:
        trigger(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=generate,
            inputs=inputs,
            outputs=result,
            api_name=generate_api_name,
        ).then(
            fn=fetch_gallery_history, inputs=[prompt, result], outputs=history, queue=False
        )
# Script entry point: enable request queuing (at most 20 pending requests)
# and start the web server.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()