Spaces:
Running
Running
import gradio as gr | |
import torch | |
from diffusers import StableDiffusionXLPipeline, AutoencoderKL, KDPM2AncestralDiscreteScheduler, UNet2DConditionModel | |
from huggingface_hub import hf_hub_download | |
import spaces | |
from PIL import Image | |
import requests | |
from translatepy import Translator | |
import numpy as np | |
import random | |
# Shared translatepy client; generate_image uses it to translate user prompts to English.
translator = Translator()
# Constants
model = "Corcelio/mobius"  # Hugging Face repo id of the text-to-image pipeline
vae_model = "madebyollin/sdxl-vae-fp16-fix"  # repo id of the VAE passed into the pipeline
MAX_SEED = np.iinfo(np.int32).max  # upper bound for random seeds and the seed slider (2**31 - 1)

# Page styling: cap the app width and hide the Gradio footer.
CSS = """
.gradio-container {
max-width: 690px !important;
}
footer {
visibility: hidden;
}
"""

# Force the dark theme by making sure the page URL carries `__theme=dark`.
# The URL API is used instead of appending '?__theme=dark' to the raw href:
# plain concatenation produced `?a=1?__theme=dark` when a query string was
# already present (theme never applied), and put the parameter inside the
# fragment when the URL had a `#hash`.
JS = """function () {
    const url = new URL(window.location.href);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.replace(url.href);
    }
}"""
# Load the model components once at import time.
# Device/precision are chosen up front: float16 on CUDA, float32 on CPU.
# The pipeline is always built (not only when CUDA is available) because
# generate_image calls `pipe` unconditionally; leaving it undefined made
# every request fail with NameError on CPU-only hosts.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

vae = AutoencoderKL.from_pretrained(
    vae_model,
    torch_dtype=DTYPE,
)
# Load the UNet directly in the target dtype rather than loading and then
# casting with .to(), which avoids a full-precision copy in memory.
unet = UNet2DConditionModel.from_pretrained(
    model,
    subfolder="unet",
    torch_dtype=DTYPE,
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    model,
    vae=vae,
    unet=unet,
    torch_dtype=DTYPE,
).to(DEVICE)
# Replace the pipeline's default scheduler with KDPM2 ancestral, reusing the
# existing scheduler configuration.
pipe.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipe.scheduler.config)
# Image-generation callback used by the Gradio UI.
def generate_image(
    prompt,
    negative="low quality",
    width=1024,
    height=1024,
    seed=-1,
    nums=1,
    scale=1.5,
    steps=30,
    clip=3):
    """Generate images for *prompt* with the module-level pipeline ``pipe``.

    The prompt is translated to English with ``translator`` first; if the
    translation raises, the original prompt is used unchanged.

    Args:
        prompt: Text prompt, in any language.
        negative: Negative prompt, passed to the pipeline unchanged.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: RNG seed; -1 picks a random seed in ``[0, MAX_SEED]``.
        nums: Number of images to generate (``num_images_per_prompt``).
        scale: Guidance scale (``guidance_scale``).
        steps: Number of inference steps.
        clip: Value passed as ``clip_skip``.

    Returns:
        Tuple ``(images, seed)``: the pipeline's ``.images`` and the seed
        actually used (useful when a random seed was drawn).
    """
    # Gradio sliders may deliver numeric values as floats; manual_seed and
    # the pipeline's size/count arguments need ints.
    seed = int(seed)
    width, height = int(width), int(height)
    nums, steps, clip = int(nums), int(steps), int(clip)

    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # Translation is best-effort: if it fails, keep the user's prompt rather
    # than failing the whole generation request.
    try:
        prompt = str(translator.translate(prompt, 'English'))
    except Exception as err:
        print(f'translation failed, using original prompt: {err!r}')
    print(f'prompt:{prompt}')

    image = pipe(
        prompt,
        negative_prompt=negative,
        width=width,
        height=height,
        guidance_scale=scale,
        generator=generator,
        num_inference_steps=steps,
        num_images_per_prompt=nums,
        clip_skip=clip,
    ).images
    return image, seed
# Sample prompts offered through gr.Examples in the interface; each one is
# fed to the prompt textbox.
examples = [
    "a cat eating a piece of cheese",
    "a ROBOT riding a BLUE horse on Mars, photorealistic",
    "Ironman VS Hulk, ultrarealistic",
    "a CUTE robot artist painting on an easel",
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
    "An alien holding sign board contain word 'Flash', futuristic, neonpunk",
    "Kids going to school, Anime style",
]
# Gradio Interface
with gr.Blocks(css=CSS, js=JS, theme="soft") as demo:
    gr.HTML("<h1><center>Mobius💠</center></h1>")
    gr.HTML("<p><center><a href='https://huggingface.co/Corcelio/mobius'>mobius</a> text-to-image generation</center><br><center>Adding default prompts to enhance.</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter Your Prompt(Multi-Languages)', value="best quality, HD, aesthetic", scale=6)
            submit = gr.Button(scale=1, variant='primary')
        img = gr.Gallery(label='Mobius Generated Image', columns=1, preview=True)
        # generate_image returns (images, seed). Each returned value needs its
        # own output component; with `outputs=img` alone Gradio passed the
        # whole tuple to the Gallery, which cannot render it. Showing the seed
        # here (read-only) keeps the user's seed slider untouched, so -1 still
        # means "random" on the next run.
        seed_used = gr.Number(label="Seed Used", interactive=False, precision=0)
    with gr.Accordion("Advanced Options", open=False):
        with gr.Row():
            negative = gr.Textbox(label="Negative prompt", value="low quality, ugly, blurry, poor face, bad anatomy")
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=512,
                maximum=1280,
                step=8,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=512,
                maximum=1280,
                step=8,
                value=1024,
            )
        with gr.Row():
            seed = gr.Slider(
                label="Seed (-1 Get Random)",
                minimum=-1,
                maximum=MAX_SEED,
                step=1,
                value=-1,
                scale=2,
            )
            nums = gr.Slider(
                label="Image Numbers",
                minimum=1,
                maximum=4,
                step=1,
                value=1,
                scale=1,
            )
        with gr.Row():
            scale = gr.Slider(
                label="Guidance",
                minimum=3.5,
                maximum=7,
                step=0.1,
                value=7,
            )
            steps = gr.Slider(
                label="Steps",
                minimum=1,
                maximum=50,
                step=1,
                value=50,
            )
            clip = gr.Slider(
                label="Clip Skip",
                minimum=1,
                maximum=10,
                step=1,
                value=3,
            )

    # Components for generate_image's two return values, in order.
    gen_outputs = [img, seed_used]
    # Components for generate_image's parameters, in signature order.
    gen_inputs = [prompt, negative, width, height, seed, nums, scale, steps, clip]

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=gen_outputs,
        fn=generate_image,
        cache_examples="lazy",
    )
    prompt.submit(fn=generate_image,
                  inputs=gen_inputs,
                  outputs=gen_outputs,
                  )
    submit.click(fn=generate_image,
                 inputs=gen_inputs,
                 outputs=gen_outputs,
                 )

demo.queue().launch()