|
|
|
|
|
import os |
|
import random |
|
import gradio as gr |
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler |
|
|
|
# Largest seed the UI/API will produce or accept: the int32 maximum (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max

# Upper bound, in pixels, for the width/height sliders; overridable via the
# MAX_IMAGE_SIZE environment variable.
MAX_IMAGE_SIZE = int(os.environ.get('MAX_IMAGE_SIZE', '1024'))

# Token every generate() call must present.
# NOTE(review): when SECRET_TOKEN is unset this falls back to a literal that is
# visible in the source, so the check then protects nothing — confirm the env
# var is always set in deployment.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')
|
|
|
# Prefer the first CUDA device; fall back to CPU when no GPU is present.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

if not torch.cuda.is_available():
    # No GPU: the model is not loaded at all and `pipe` is left as None, so
    # anything that calls it must handle the no-GPU case itself.
    pipe = None
else:
    # fp16 UNet from the "latent-consistency/lcm-ssd-1b" checkpoint ...
    unet = UNet2DConditionModel.from_pretrained(
        "latent-consistency/lcm-ssd-1b",
        variant="fp16",
        torch_dtype=torch.float16,
    )
    # ... swapped into the fp16 "segmind/SSD-1B" pipeline in place of its own UNet.
    pipe = DiffusionPipeline.from_pretrained(
        "segmind/SSD-1B",
        unet=unet,
        variant="fp16",
        torch_dtype=torch.float16,
    )
    # Replace the default scheduler with an LCMScheduler built from the same config.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return *seed*, or a fresh random seed when *randomize_seed* is truthy.

    The fresh seed is drawn with ``random.randint(0, MAX_SEED)``, i.e. from
    the inclusive range ``[0, MAX_SEED]`` using the module-level RNG.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
|
|
def generate(prompt: str,
             negative_prompt: str = '',
             use_negative_prompt: bool = False,
             seed: int = 0,
             width: int = 1024,
             height: int = 1024,
             guidance_scale: float = 1.0,
             num_inference_steps: int = 6,
             secret_token: str = '') -> PIL.Image.Image:
    """Generate one image from *prompt* with the module-level ``pipe``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt; only forwarded when
            *use_negative_prompt* is true.
        use_negative_prompt: When false, ``None`` is passed as the negative
            prompt regardless of *negative_prompt*.
        seed: Seed for the CPU ``torch.Generator`` that drives sampling.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.
        secret_token: Must equal ``SECRET_TOKEN`` for the request to run.

    Returns:
        The first image of the pipeline output (``output_type='pil'``).

    Raises:
        gr.Error: If *secret_token* is not a string equal to ``SECRET_TOKEN``,
            or if no pipeline is loaded (``pipe`` is ``None`` on hosts
            without CUDA).
    """
    import secrets  # local import: only needed for the token check below

    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Compare UTF-8 bytes because compare_digest raises
    # TypeError for str arguments containing non-ASCII characters; non-str
    # tokens are rejected, as the original `!=` check would have done.
    if not isinstance(secret_token, str) or not secrets.compare_digest(
            secret_token.encode('utf-8'), SECRET_TOKEN.encode('utf-8')):
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    # Without CUDA the module never loads a model and leaves `pipe` as None;
    # fail with a readable message instead of "'NoneType' object is not callable".
    if pipe is None:
        raise gr.Error('No GPU is available, so the model is not loaded and images cannot be generated.')

    # Slider/API values can arrive as floats; manual_seed and the pipeline's
    # size/step arguments need integers.
    generator = torch.Generator().manual_seed(int(seed))

    if not use_negative_prompt:
        negative_prompt = None

    return pipe(prompt=prompt,
                negative_prompt=negative_prompt,
                width=int(width),
                height=int(height),
                guidance_scale=guidance_scale,
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                output_type='pil').images[0]
|
|
|
with gr.Blocks() as demo:
    # Full-viewport white overlay (position: fixed, 100% width/height,
    # z-index: 100) placed over the controls below, so browser visitors only
    # see this notice; the app is meant to be called programmatically.
    gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a REST API to programmatically generate images.</p>
<p style="color: black;">Please see the <a href="https://hotshot.co" target="_blank">README.md</a> for more information.</p>
</div>
</div>""")

    secret_token = gr.Text(
        label='Secret Token',
        max_lines=1,
        placeholder='Enter your secret token',
        # Mask the token while it is typed; the submitted value is unchanged.
        type='password',
    )
    prompt = gr.Text(
        label='Prompt',
        show_label=False,
        max_lines=1,
        placeholder='Enter your prompt',
        container=False,
    )
    run_button = gr.Button('Run', scale=0)
    result = gr.Image(label='Result', show_label=False)

    use_negative_prompt = gr.Checkbox(label='Use negative prompt', value=False)
    # Hidden until the checkbox above is ticked (see the .change handler below).
    negative_prompt = gr.Text(
        label='Negative prompt',
        max_lines=1,
        placeholder='Enter a negative prompt',
        visible=False,
    )
    seed = gr.Slider(label='Seed',
                     minimum=0,
                     maximum=MAX_SEED,
                     step=1,
                     value=0)
    randomize_seed = gr.Checkbox(label='Randomize seed', value=True)

    width = gr.Slider(
        label='Width',
        minimum=256,
        maximum=MAX_IMAGE_SIZE,
        step=32,
        value=1024,
    )
    height = gr.Slider(
        label='Height',
        minimum=256,
        maximum=MAX_IMAGE_SIZE,
        step=32,
        value=1024,
    )
    guidance_scale = gr.Slider(
        label='Guidance scale',
        minimum=1,
        maximum=20,
        step=0.1,
        value=1.0)
    num_inference_steps = gr.Slider(
        label='Number of inference steps',
        minimum=2,
        maximum=40,
        step=1,
        value=6)

    # Show/hide the negative-prompt box to follow the checkbox.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt
    )

    # Order must match generate()'s positional parameters.
    inputs = [
        prompt,
        negative_prompt,
        use_negative_prompt,
        seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        secret_token,
    ]

    # Each trigger first (optionally) re-rolls the seed into the slider, then
    # runs generate() with the updated value. Only the prompt-submit chain is
    # exposed under the named API endpoint 'run'.
    prompt.submit(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name='run',
    )
    negative_prompt.submit(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result
    )
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result
    )

demo.queue(max_size=6).launch()