# STABLE-HAMSTER / app.py
# Author: prithivMLmods · commit ddab8a4 ("Update app.py", verified) · 16.2 kB
import os
import random
import uuid
import json
import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
# Runtime configuration, overridable through environment variables.
MODEL_ID = os.getenv("MODEL_ID", "Corcelio/mobius")
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # Images per pipeline call

# Load the model once at import time so every request reuses the same pipeline.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

if ENABLE_CPU_OFFLOAD:
    # Model offloading manages device placement itself (moving each sub-model to
    # the GPU only while it runs), so the full pipeline is not moved up front.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)

if USE_TORCH_COMPILE:
    # A diffusers pipeline is not an nn.Module and has no `.compile()`; compile the
    # UNet, which dominates the denoising loop, instead (experimental).
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

# Largest seed offered to users (int32 max).
MAX_SEED = np.iinfo(np.int32).max
def save_image(img):
    """Save ``img`` under a random UUID-based ``.png`` name and return that name.

    The file is written relative to the current working directory.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in ``[0, MAX_SEED]`` when requested, else ``seed``."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
@spaces.GPU(duration=35, enable_queue=True)
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 30,
    randomize_seed: bool = False,
    use_resolution_binning: bool = True,
    num_images: int = 1,  # Number of images to generate
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images for ``prompt`` with the shared pipeline and save them as PNGs.

    Args:
        prompt: Text prompt, repeated once per requested image.
        negative_prompt: Negative prompt; only sent when ``use_negative_prompt`` is true.
        use_negative_prompt: Whether to pass ``negative_prompt`` to the pipeline.
        seed: Seed for the torch generator (replaced when ``randomize_seed`` is true).
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        randomize_seed: Draw a fresh random seed instead of using ``seed``.
        use_resolution_binning: When true, forwarded to the pipeline as a keyword.
        num_images: Total number of images; produced ``BATCH_SIZE`` per pipeline call.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        ``(image_paths, seed)``: the saved PNG file names and the seed actually used.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    # Gradio number widgets may deliver floats; list repetition and the pipeline
    # need proper ints.
    num_images = int(num_images)

    options = {
        "width": int(width),
        "height": int(height),
        "guidance_scale": float(guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "generator": generator,
        "output_type": "pil",
    }
    if use_resolution_binning:
        # NOTE(review): StableDiffusionXLPipeline.__call__ appears to accept and
        # ignore unknown **kwargs, so this flag may have no effect — confirm.
        options["use_resolution_binning"] = True

    prompts = [prompt] * num_images
    # Only build negative prompts when requested; `None` means "do not send".
    negative_prompts = [negative_prompt] * num_images if use_negative_prompt else None

    # A non-positive BATCH_SIZE would make range() fail; treat it as 1.
    step = max(1, BATCH_SIZE)
    images = []
    for start in range(0, num_images, step):
        batch_options = dict(options, prompt=prompts[start:start + step])
        if negative_prompts is not None:
            batch_options["negative_prompt"] = negative_prompts[start:start + step]
        images.extend(pipe(**batch_options).images)

    image_paths = [save_image(img) for img in images]
    return image_paths, seed
# Sample prompts offered to users in the UI.
examples = [
    "a cat eating a piece of cheese",
    "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
    "Ironman VS Hulk, ultrarealistic",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
    "Kids going to school, Anime style",
]
# Custom page CSS passed to gr.Blocks: narrows the Gradio container, centers h1,
# hides the footer, and defines the "wheel-and-hamster" loading animation
# (wheel, spokes, hamster parts and their @keyframes) plus the `.hidden` helper.
# NOTE(review): `.spoke` sets `--s: 0.2;` with no unit, so `width`/`height:
# var(--s)` and the `left: calc(...)` that uses it are invalid at computed-value
# time; the spokes likely do not render as intended — confirm the intended size
# (e.g. `0.2em`).
css = '''
.gradio-container{max-width: 700px !important}
h1{text-align:center}
footer {
visibility: hidden
}
.wheel-and-hamster {
--dur: 1s;
position: relative;
width: 12em;
height: 12em;
font-size: 14px;
}
.wheel,
.hamster,
.hamster div,
.spoke {
position: absolute;
}
.wheel,
.spoke {
border-radius: 50%;
top: 0;
left: 0;
width: 100%;
height: 100%;
}
.wheel {
background: radial-gradient(100% 100% at center,hsla(0,0%,60%,0) 47.8%,hsl(0,0%,60%) 48%);
z-index: 2;
}
.hamster {
animation: hamster var(--dur) ease-in-out infinite;
top: 50%;
left: calc(50% - 3.5em);
width: 7em;
height: 3.75em;
transform: rotate(4deg) translate(-0.8em,1.85em);
transform-origin: 50% 0;
z-index: 1;
}
.hamster__head {
animation: hamsterHead var(--dur) ease-in-out infinite;
background: hsl(30,90%,55%);
border-radius: 70% 30% 0 100% / 40% 25% 25% 60%;
box-shadow: 0 -0.25em 0 hsl(30,90%,80%) inset,
0.75em -1.55em 0 hsl(30,90%,90%) inset;
top: 0;
left: -2em;
width: 2.75em;
height: 2.5em;
transform-origin: 100% 50%;
}
.hamster__ear {
animation: hamsterEar var(--dur) ease-in-out infinite;
background: hsl(0,90%,85%);
border-radius: 50%;
box-shadow: -0.25em 0 hsl(30,90%,55%) inset;
top: -0.25em;
right: -0.25em;
width: 0.75em;
height: 0.75em;
transform-origin: 50% 75%;
}
.hamster__eye {
animation: hamsterEye var(--dur) linear infinite;
background-color: hsl(0,0%,0%);
border-radius: 50%;
top: 0.375em;
left: 1.25em;
width: 0.5em;
height: 0.5em;
}
.hamster__nose {
background: hsl(0,90%,75%);
border-radius: 35% 65% 85% 15% / 70% 50% 50% 30%;
top: 0.75em;
left: 0;
width: 0.2em;
height: 0.25em;
}
.hamster__body {
animation: hamsterBody var(--dur) ease-in-out infinite;
background: hsl(30,90%,90%);
border-radius: 50% 30% 50% 30% / 15% 60% 40% 40%;
box-shadow: 0.1em 0.75em 0 hsl(30,90%,55%) inset,
0.15em -0.5em 0 hsl(30,90%,80%) inset;
top: 0.25em;
left: 2em;
width: 4.5em;
height: 3em;
transform-origin: 17% 50%;
transform-style: preserve-3d;
}
.hamster__limb--fr,
.hamster__limb--fl {
clip-path: polygon(0 0,100% 0,70% 80%,60% 100%,0% 100%,40% 80%);
top: 2em;
left: 0.5em;
width: 1em;
height: 1.5em;
transform-origin: 50% 0;
}
.hamster__limb--fr {
animation: hamsterFRLimb var(--dur) linear infinite;
background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%);
transform: rotate(15deg) translateZ(-1px);
}
.hamster__limb--fl {
animation: hamsterFLLimb var(--dur) linear infinite;
background: linear-gradient(hsl(30,90%,80%) 80%,hsl(0,90%,75%) 80%);
transform: rotate(-60deg) translateZ(-1px);
}
.hamster__limb--br,
.hamster__limb--bl {
clip-path: polygon(0 0,100% 0,100% 30%,70% 80%,60% 100%,40% 100%,30% 80%);
top: 2.3em;
left: 2.8em;
width: 1.25em;
height: 2.5em;
transform-origin: 50% 10%;
}
.hamster__limb--br {
animation: hamsterBRLimb var(--dur) linear infinite;
background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%);
transform: rotate(45deg) translateZ(-1px);
}
.hamster__limb--bl {
animation: hamsterBLLimb var(--dur) linear infinite;
background: linear-gradient(hsl(0,90%,75%) 40%,hsl(30,90%,80%) 40%);
transform: rotate(-30deg) translateZ(-1px);
}
.hamster__tail {
animation: hamsterTail var(--dur) linear infinite;
background: hsl(0,90%,85%);
border-radius: 0.25em 50% 50% 0.25em;
box-shadow: 0.1em 0.5em 0 hsl(30,90%,55%) inset,
0.1em -0.25em 0 hsl(30,90%,90%) inset;
top: 3em;
left: 6em;
width: 0.75em;
height: 0.75em;
transform: rotate(30deg) translateZ(-1px);
}
.spoke {
--s: 0.2;
background: hsl(0,0%,100%);
box-shadow: 0 0 0 0.2em hsl(0,0%,0%);
left: calc(50% - var(--s)/2);
width: var(--s);
height: var(--s);
}
.spoke:nth-child(1) {
--rotation: 15deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(2) {
--rotation: 45deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(3) {
--rotation: 75deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(4) {
--rotation: 105deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(5) {
--rotation: 135deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(6) {
--rotation: 165deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(7) {
--rotation: 195deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(8) {
--rotation: 225deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(9) {
--rotation: 255deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(10) {
--rotation: 285deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(11) {
--rotation: 315deg;
transform: rotate(var(--rotation));
}
.spoke:nth-child(12) {
--rotation: 345deg;
transform: rotate(var(--rotation));
}
@keyframes hamster {
50% {
transform: rotate(-4deg) translate(-0.8em,1.85em);
}
}
@keyframes hamsterHead {
50% {
transform: rotate(-8deg);
}
}
@keyframes hamsterEye {
50% {
transform: translateY(0.1em);
}
}
@keyframes hamsterEar {
50% {
transform: rotate(8deg);
}
}
@keyframes hamsterBody {
50% {
transform: rotate(2deg);
}
}
@keyframes hamsterFRLimb {
8%,70% {
transform: rotate(15deg) translateZ(-1px);
}
33% {
transform: rotate(-60deg) translateZ(-1px);
}
83% {
transform: rotate(45deg) translateZ(-1px);
}
}
@keyframes hamsterFLLimb {
8%,70% {
transform: rotate(-60deg) translateZ(-1px);
}
33% {
transform: rotate(15deg) translateZ(-1px);
}
83% {
transform: rotate(-45deg) translateZ(-1px);
}
}
@keyframes hamsterBRLimb {
0%,50% {
transform: rotate(45deg) translateZ(-1px);
}
25% {
transform: rotate(-30deg) translateZ(-1px);
}
75% {
transform: rotate(60deg) translateZ(-1px);
}
}
@keyframes hamsterBLLimb {
0%,50% {
transform: rotate(-30deg) translateZ(-1px);
}
25% {
transform: rotate(45deg) translateZ(-1px);
}
75% {
transform: rotate(-45deg) translateZ(-1px);
}
}
@keyframes hamsterTail {
50% {
transform: rotate(-30deg) translateZ(-1px);
}
}
#wrapper {
position: relative;
height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.loading-container {
position: relative;
display: flex;
justify-content: center;
align-items: center;
width: 100%;
height: 100%;
top: 50%;
transform: translateY(-50%);
}
.hidden {
display: none;
}
'''
# Markup rendered via gr.HTML: a "wheel-and-hamster" loading animation (12 spokes
# plus hamster body parts) and an inline script that watches the DOM, and once an
# element matching "#root" exists, adds the `hidden` class to `.loading-container`
# and stops observing.
# NOTE(review): the script only starts observing inside a DOMContentLoaded
# listener. If this HTML is injected after the page has loaded, that event will
# not fire again. Browsers also do not execute <script> tags inserted via
# innerHTML. Confirm how gr.HTML renders this and whether the loader is ever
# hidden; while visible, #wrapper reserves 100vh of height above the app.
html = '''
<div id="wrapper">
<div class="loading-container">
<div class="wheel-and-hamster">
<div class="wheel">
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
<div class="spoke"></div>
</div>
<div class="hamster">
<div class="hamster__body">
<div class="hamster__head">
<div class="hamster__ear"></div>
<div class="hamster__eye"></div>
<div class="hamster__nose"></div>
</div>
<div class="hamster__limb hamster__limb--fr"></div>
<div class="hamster__limb hamster__limb--fl"></div>
<div class="hamster__limb hamster__limb--br"></div>
<div class="hamster__limb hamster__limb--bl"></div>
<div class="hamster__tail"></div>
</div>
</div>
</div>
</div>
</div>
<script>
// Wait for the Gradio app to load
document.addEventListener("DOMContentLoaded", function() {
const observer = new MutationObserver(function(mutationsList, observer) {
for (const mutation of mutationsList) {
if (mutation.type === "childList" && mutation.addedNodes.length > 0) {
// Check if Gradio has loaded by looking for a specific element
const gradioApp = document.querySelector("#root");
if (gradioApp) {
// Hide the loading animation and observer
document.querySelector(".loading-container").classList.add("hidden");
observer.disconnect();
}
}
}
});
// Start observing the body for changes
observer.observe(document.body, { childList: true, subtree: true });
});
</script>
'''
# Build the UI: the loading animation, a column of generation controls, and a
# results column with example prompts and a gallery of generated images.
block = gr.Blocks(css=css)
with block as demo:
    gr.HTML(html)
    with gr.Column(elem_id="main-app"):
        gr.Markdown("# Generate images with Stable Diffusion XL")
        with gr.Row():
            with gr.Column():
                text = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt",
                    lines=1,
                )
                negative_text = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Enter negative prompt",
                    lines=1,
                )
                negative_prompt_chk = gr.Checkbox(
                    label="Use Negative Prompt",
                    value=True,
                )
                seed = gr.Number(
                    label="Seed",
                    value=1,
                    precision=0,
                )
                randomize_seed = gr.Checkbox(
                    label="Randomize Seed",
                    value=False,
                )
                width = gr.Slider(
                    label="Width",
                    value=1024,
                    minimum=64,
                    maximum=MAX_IMAGE_SIZE,
                    step=8,
                )
                height = gr.Slider(
                    label="Height",
                    value=1024,
                    minimum=64,
                    maximum=MAX_IMAGE_SIZE,
                    step=8,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    value=3,
                    minimum=0,
                    maximum=50,
                    step=0.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of Inference Steps",
                    value=30,
                    minimum=1,
                    maximum=100,
                    step=1,
                )
                use_resolution_binning = gr.Checkbox(
                    label="Use Resolution Binning",
                    value=True,
                )
                num_images = gr.Number(
                    label="Number of Images to Generate",
                    value=1,
                    minimum=1,
                    maximum=10,
                    step=1,
                    precision=0,  # deliver an int: it sizes the prompt list
                )
                generate_button = gr.Button("Generate Images")
            with gr.Column():
                # Clicking an example copies it into the prompt box.
                gr.Examples(examples=examples, inputs=text)
                # `generate` returns a list of file paths; a Gallery displays all
                # of them, while gr.Image shows only one (and needs type="pil").
                generated_images = gr.Gallery(label="Generated Images")

    # Event handlers receive the components' *current* values. Reading
    # `component.value` would only give the initial defaults. The input order
    # matches the positional parameters of `generate`.
    generate_button.click(
        fn=generate,
        inputs=[
            text,
            negative_text,
            negative_prompt_chk,
            seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            randomize_seed,
            use_resolution_binning,
            num_images,
        ],
        outputs=[generated_images, seed],
    )

if __name__ == "__main__":
    demo.queue().launch()