"""Gradio Space: sample a random tag prompt with Dart v3, then render it with an SDXL pipeline."""

# Keep `spaces` as the first import: ZeroGPU expects to be imported before
# torch initialises CUDA.
import spaces

import math
import os
import random

import torch
import numpy as np
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
    EulerAncestralDiscreteScheduler,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Local-development convenience: read a .env file when python-dotenv is
# installed. Only a missing package is tolerated; other errors surface.
try:
    from dotenv import load_dotenv
except ImportError:
    print("failed to import dotenv (this is not a problem on the production)")
else:
    load_dotenv()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Explicit checks rather than `assert`, which is stripped under `python -O`.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError("HF_TOKEN environment variable must be set")

IMAGE_MODEL_REPO_ID = os.environ.get(
    "IMAGE_MODEL_REPO_ID", "OnomaAIResearch/Illustrious-xl-early-release-v0"
)
DART_V3_REPO_ID = os.environ.get("DART_V3_REPO_ID", None)
if DART_V3_REPO_ID is None:
    raise RuntimeError("DART_V3_REPO_ID environment variable must be set")

torch_dtype = torch.bfloat16

# Prompt-generation model (Dart v3).
# NOTE(review): `dart` is never moved to `device`, so prompt generation runs
# on CPU — confirm this is intended.
dart = AutoModelForCausalLM.from_pretrained(
    DART_V3_REPO_ID,
    torch_dtype=torch_dtype,
    token=HF_TOKEN,
    use_cache=True,
)
dart = dart.eval()
dart = dart.requires_grad_(False)
dart = torch.compile(dart)
# Same repo as the model above, so authenticate the same way (the repo may be
# private or gated).
tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID, token=HF_TOKEN)

# Image model. "lpw_stable_diffusion_xl" is the long-prompt-weighting community
# pipeline, presumably what interprets the `(...:1.2)` weight syntax used in the
# default negative prompt.
pipe = StableDiffusionXLPipeline.from_pretrained(
    IMAGE_MODEL_REPO_ID,
    torch_dtype=torch_dtype,
    add_watermarker=False,
    custom_pipeline="lpw_stable_diffusion_xl",
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_attn_processor(AttnProcessor2_0())
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Dart v3 prompt template; `{aspect_ratio}` is filled with a tag from
# get_aspect_ratio(). Commented-out pieces are optional template tokens that
# are currently disabled.
TEMPLATE = (
    "<|bos|>"
    # "<|rating:general|>"
    "{aspect_ratio}"
    "<|length:medium|>"
    # "original"
    # ""
    # ""
)


def get_aspect_ratio(width: int, height: int) -> str:
    """Return the Dart v3 aspect-ratio tag for an image of ``width`` x ``height``.

    ``ar = width / height``: ``ar < 1`` is a portrait (tall) image and
    ``ar > 1`` a landscape (wide) one. Thresholds are 1/sqrt(3), 8/9, 9/8 and
    sqrt(3).

    Raises:
        ValueError: if either dimension is not positive.
    """
    if width <= 0 or height <= 0:
        raise ValueError(f"width and height must be positive, got {width}x{height}")
    ar = width / height
    if ar <= 1 / math.sqrt(3):
        return "<|aspect_ratio:ultra_tall|>"
    elif ar <= 8 / 9:
        return "<|aspect_ratio:tall|>"
    elif ar < 9 / 8:
        return "<|aspect_ratio:square|>"
    elif ar < math.sqrt(3):
        return "<|aspect_ratio:wide|>"
    else:
        return "<|aspect_ratio:ultra_wide|>"


@torch.inference_mode()
def generate_prompt(aspect_ratio: str) -> str:
    """Sample a random comma-separated tag prompt from Dart v3.

    Args:
        aspect_ratio: an aspect-ratio tag as returned by get_aspect_ratio().

    Returns:
        The newly generated tokens, decoded individually, with blank pieces
        dropped, joined by ", ". The template tokens are not included.
    """
    input_ids = tokenizer.encode_plus(
        TEMPLATE.format(aspect_ratio=aspect_ratio),
        return_tensors="pt",
    ).input_ids  # shape (1, prompt_len)
    print("input_ids:", input_ids)

    output_ids = dart.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
        top_k=100,
        num_beams=1,
    )[0]

    # generate() returns prompt + continuation. `input_ids` is 2-D, so
    # len(input_ids) would be the batch size (1); slice by the sequence length.
    generated = output_ids[input_ids.shape[-1]:]
    # batch_decode over a 1-D tensor decodes each token id on its own, so each
    # token (presumably one tag) becomes one comma-separated item.
    pieces = tokenizer.batch_decode(generated, skip_special_tokens=True)
    decoded = ", ".join(piece for piece in pieces if piece.strip() != "")
    print("decoded:", decoded)
    return decoded


def format_prompt(prompt: str, prompt_suffix: str) -> str:
    """Join the generated prompt and the user suffix with ", ".

    Blank/empty parts are skipped so an empty suffix does not leave a dangling
    trailing ", ".
    """
    return ", ".join(part for part in (prompt, prompt_suffix) if part and part.strip())


@spaces.GPU
def generate_image(
    prompt: str,
    negative_prompt: str,
    generator,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
):
    """Render a single image with the SDXL pipeline and return it."""
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image


def on_generate(
    suffix: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Click handler: sample a random prompt with Dart v3 and render it.

    ``progress`` is not read here; declaring it lets Gradio track tqdm bars.

    Returns:
        ``(image, prompt, seed)`` — the seed is returned so the UI slider
        reflects the value actually used when ``randomize_seed`` is on.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; manual_seed, the pipeline's size
    # arguments and the scheduler's step count all need ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator().manual_seed(seed)

    ar = get_aspect_ratio(width, height)
    prompt = format_prompt(generate_prompt(ar), suffix)
    print(prompt)

    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )
    return image, prompt, seed


css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Random IllustriousXL")

        with gr.Row():
            run_button = gr.Button("Generate random", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Generation details", open=False):
            prompt_txt = gr.Textbox(label="Generated prompt", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            prompt_suffix = gr.Text(
                label="Prompt suffix",
                visible=True,
                value="masterpiece, best quality, very aesthetic",
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                visible=True,
                value=(
                    "(worst quality, bad quality, low quality:1.2), lowres, displeasing, "
                    "very displeasing, bad anatomy, bad hands, extra digits, fewer digits, "
                    "scan artifacts, signature, username, jpeg artifacts, retro, 2010s"
                ),
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                # Default 832x1152 is a portrait size.
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=832,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1152,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=7,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=20,
                    maximum=50,
                    step=1,
                    value=25,
                )

    gr.on(
        triggers=[run_button.click],
        fn=on_generate,
        inputs=[
            prompt_suffix,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )

demo.queue().launch()