# CyranoB — commit b4cd6e1: "Fix HF login"
import functools
import os
import random
import uuid

import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
from diffusers import StableDiffusion3Pipeline, DPMSolverMultistepScheduler, AutoencoderKL, StableDiffusion3Img2ImgPipeline
from transformers import T5EncoderModel, BitsAndBytesConfig
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the HUGGINGFACE_TOKEN env var.
# `login(token=None)` falls back to an interactive prompt, which a headless
# server process cannot answer, so only call it when a token is configured.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(token=huggingface_token)

# Header text; a CPU warning is appended when no CUDA device is available.
# NOTE(review): appears not to be rendered by the UI in this file — confirm.
DESCRIPTION = """# Stable Diffusion 3"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"

# Largest seed the UI offers: the maximum value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Passed to gr.Examples(cache_examples=...); examples are not precomputed.
CACHE_EXAMPLES = False
# Overridable via the MAX_IMAGE_SIZE env var.
# NOTE(review): presumably a cap on image side length, but not read in this file.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
# NOTE(review): the two flags below are not read anywhere in this file.
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
# First CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@functools.lru_cache(maxsize=1)
def load_pipeline(model_id="stabilityai/stable-diffusion-3-medium-diffusers"):
    """Load a Stable Diffusion 3 pipeline in float16, memoised.

    `generate` calls this on every request. Without the cache, each request
    would re-instantiate the whole pipeline via ``from_pretrained``. With it,
    later calls with the same ``model_id`` return the same pipeline object.

    Args:
        model_id: Hub repo id (or local path) of a diffusers SD3 checkpoint.
            Defaults to the SD3 medium model used by the original code.

    Returns:
        The loaded ``StableDiffusion3Pipeline``. It is not moved to any device
        here; the caller is responsible for ``.to(device)``.
    """
    return StableDiffusion3Pipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
    )
# Supported aspect ratios: "W:H" label -> (width, height) proportions.
# Insertion order is the order shown in the UI dropdown.
aspect_ratios = {
    f"{w}:{h}": (w, h)
    for w, h in (
        (21, 9),
        (2, 1),
        (16, 9),
        (5, 4),
        (4, 3),
        (3, 2),
        (1, 1),
    )
}


def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=64):
    """Choose an output ``(width, height)`` for a named aspect ratio.

    Both sides are rounded down to multiples of ``divisibility``, and the area
    is kept at or below ``total_pixels``. In ``'portrait'`` mode the ratio is
    inverted so the image is taller than it is wide. Any other ``mode`` value
    is treated as landscape.

    Args:
        aspect_ratio: A key of ``aspect_ratios``, e.g. ``"16:9"``.
        mode: ``'landscape'`` or ``'portrait'``.
        total_pixels: Upper bound on ``width * height``.
        divisibility: Both returned sides are multiples of this value.

    Returns:
        A ``(width, height)`` tuple of ints.

    Raises:
        ValueError: If ``aspect_ratio`` is not a key of ``aspect_ratios``.
    """
    if aspect_ratio not in aspect_ratios:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")

    def floor_to_step(value):
        # Round a non-negative int down to the nearest multiple of `divisibility`.
        return (value // divisibility) * divisibility

    w_part, h_part = aspect_ratios[aspect_ratio]
    ratio = w_part / h_part
    if mode == 'portrait':
        ratio = 1 / ratio

    height = floor_to_step(int((total_pixels / ratio) ** 0.5))
    width = floor_to_step(int(height * ratio))

    # Safety net: shrink in whole steps until the area fits the pixel budget.
    while width * height > total_pixels:
        height -= divisibility
        width = floor_to_step(int(height * ratio))

    return width, height
def save_image(img):
    """Save ``img`` under a fresh UUID-based ``.png`` file name.

    The name is relative, so the file lands in the current working directory.

    Returns:
        The generated file name.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed`` as-is, or a new random seed in ``[0, MAX_SEED]``.

    A new seed is drawn only when ``randomize_seed`` is truthy.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    aspect: str = "1:1",
    mode: str = "landscape",
    guidance_scale: float = 7.5,
    randomize_seed: bool = False,
    num_inference_steps=30,
    NUM_IMAGES_PER_PROMPT=1,
    use_resolution_binning: bool = True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images for ``prompt`` with the Stable Diffusion 3 pipeline.

    Args:
        prompt: Text prompt for the image(s).
        negative_prompt: Text forwarded as the pipeline's negative prompt.
            Ignored unless ``use_negative_prompt`` is true.
        use_negative_prompt: When false, ``negative_prompt`` is replaced by None.
        seed: Seed for the random generator. Replaced by a random seed when
            ``randomize_seed`` is true.
        aspect: Aspect-ratio key passed to ``calculate_resolution``.
        mode: ``'landscape'`` or ``'portrait'``, passed to ``calculate_resolution``.
        guidance_scale: Forwarded to the pipeline's ``guidance_scale``.
        randomize_seed: Whether to ignore ``seed`` and draw a new one.
        num_inference_steps: Number of denoising steps. Coerced to int.
        NUM_IMAGES_PER_PROMPT: Number of images to generate. Coerced to int.
        use_resolution_binning: Accepted for interface compatibility; unused.
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars in the UI.

    Returns:
        The list of PIL images produced by the pipeline (``output_type="pil"``).
    """
    pipe = load_pipeline()
    pipe.to(device)

    seed = int(randomize_seed_fn(seed, randomize_seed))
    # CPU generator seeded explicitly so the same seed gives the same draws.
    generator = torch.Generator().manual_seed(seed)

    if not use_negative_prompt:
        negative_prompt = None  # type: ignore

    width, height = calculate_resolution(aspect, mode)

    # Slider values and API callers may deliver floats (e.g. 30.0). The
    # pipeline needs integer step counts and batch sizes.
    num_inference_steps = int(num_inference_steps)
    num_images = int(NUM_IMAGES_PER_PROMPT)

    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        num_images_per_prompt=num_images,
        output_type="pil",
    ).images
# Sample prompts shown under the UI via gr.Examples.
examples = [
    'Beautiful pixel art of a wizard with hovering text "Achievement unlocked: Diffusion models can spell now"',
    'Frog sitting in a 1950s diner wearing a leather jacket and a top hat. on the table a giant burger and a small sign that says "froggy fridays"',
    'This dreamlike digital art capture a vibrant kaleidoscopic bird in a rainforest',
    'pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo saying "SD3 rocks!"',
    'post-apocalyptic city wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic',
    'a dark-armored warrior with ornate golden details, cloaked in a flowing black cape, wielding a radiant, fiery sword, standing amidst an ominous cloudy backdrop with dramatic lighting, exuding a menacing, powerful presence.',
    'A wise old wizard with a long white beard, flowing robes, and a gnarled staff, casting a spell, photorealistic style',
    'Design a film poster for a noir thriller set in 1940s Los Angeles, featuring a shadowy figure under a streetlamp and a foggy, mysterious ambiance.',
]

# Custom CSS passed to gr.Blocks: caps the container width and centres h1 headings.
css = "\n".join([
    "",
    ".gradio-container{max-width: 1000px !important}",
    "h1{text-align:center}",
    "",
])
# Build the Gradio UI and wire its events to `generate`.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(
                """
<h1 style='text-align: center'>
Stable Diffusion 3
</h1>
"""
            )
            # Prompt box, run button, and output-shape selectors.
            with gr.Group():
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)
                with gr.Row():
                    aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios.keys()), value='1:1', interactive=True)
                    mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')
            result = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
            # Less common generation parameters, collapsed by default.
            with gr.Accordion("Advanced options", open=False):
                with gr.Row():
                    use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
                    negative_prompt = gr.Text(
                        label="Negative prompt",
                        max_lines=1,
                        value = "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
                        visible=True,
                    )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # Start at 1: a step count of 0 would leave the sampler with no
                # denoising steps, which is not a usable request.
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=60,
                    step=1,
                    value=30,
                )
                number_image = gr.Slider(
                    label="Number of Images",
                    minimum=1,
                    maximum=2,
                    step=1,
                    value=1,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.1,
                        maximum=10,
                        step=0.1,
                        value=7.0,
                    )
            # Clicking an example fills only the prompt; other generate()
            # arguments fall back to their defaults.
            gr.Examples(
                examples=examples,
                inputs=prompt,
                outputs=[result],
                fn=generate,
                cache_examples=CACHE_EXAMPLES,
            )

    # Show or hide the negative-prompt box together with its checkbox.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )
    # Enter in either text box, or the Run button, triggers generation.
    # Input order must match generate()'s positional parameters.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            use_negative_prompt,
            seed,
            aspect,
            mode,
            guidance_scale,
            randomize_seed,
            steps,
            number_image,
        ],
        outputs=[result],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start serving the app.
    app = demo.queue()
    app.launch()