"""Gradio app comparing the original and compressed Stable Diffusion pipelines.

Run as a script to build the side-by-side comparison UI and start the server.
"""

import os
import subprocess
from pathlib import Path

import gradio as gr
import torch

from demo import SdmCompressionDemo

# Static UI assets are resolved relative to the current working directory,
# so the script is expected to be launched from the project root.
_DOCS_DIR = Path('docs')


def _read_doc(filename: str) -> str:
    """Return the text of ``docs/<filename>``, decoded as UTF-8.

    The encoding is passed explicitly so the assets decode the same way on
    every platform instead of depending on the locale's default encoding.
    """
    return (_DOCS_DIR / filename).read_text(encoding='utf-8')


def _build_result_panel(title_html: str, image_label: str, params_text):
    """Create one result column and return ``(image, error, time)`` components.

    Args:
        title_html: HTML heading rendered at the top of the column.
        image_label: Label of the generated-image component.
        params_text: Initial value of the read-only "# Parameters" textbox.

    Returns:
        The image, error-markdown and inference-time components, in the order
        the inference callbacks' outputs are wired.
    """
    with gr.Column(variant='panel', scale=35):
        gr.Markdown(title_html)
        image = gr.Image(label=image_label)
        with gr.Row(equal_height=True):
            with gr.Column():
                test_time = gr.Textbox(value="", label="Inference Time (sec)")
                gr.Textbox(value=params_text, label="# Parameters")
        error = gr.Markdown()
    return image, error, test_time


def build_demo(servicer: SdmCompressionDemo) -> gr.Blocks:
    """Assemble the comparison UI and wire it to ``servicer``'s callbacks.

    Args:
        servicer: Demo backend providing ``get_example_list``,
            ``get_sdm_params``, ``pipe_original``, ``pipe_compressed``,
            ``infer_original_model`` and ``infer_compressed_model``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a copy of this file whose indentation had been lost;
    # confirm the placement of the seed slider, the parameter textboxes and
    # the error areas against the intended layout.
    example_list = servicer.get_example_list()

    with gr.Blocks(theme='nota-ai/theme', css=_read_doc('main.css')) as demo:
        gr.Markdown(_read_doc('header.md'))
        gr.Markdown(_read_doc('description.md'))

        with gr.Row():
            # Left column: prompt, generate buttons and sampling settings.
            with gr.Column(variant='panel', scale=30):
                text = gr.Textbox(
                    label="Input Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                )

                with gr.Row(equal_height=True):
                    generate_original_button = gr.Button(
                        value="Generate with Original Model",
                        variant="primary",
                    )
                    generate_compressed_button = gr.Button(
                        value="Generate with Compressed Model",
                        variant="primary",
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    negative = gr.Textbox(
                        label='Negative Prompt',
                        placeholder='Enter aspects to remove (e.g., low quality)',
                    )
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            value=7.5, minimum=4, maximum=11, step=0.5,
                        )
                        steps = gr.Slider(
                            label="Denoising Steps",
                            value=25, minimum=10, maximum=75, step=5,
                        )
                    seed = gr.Slider(
                        0, 999999, label='Random Seed', value=1234, step=1,
                    )

                with gr.Tab("Example Prompts"):
                    # Selecting an example fills the prompt textbox.
                    gr.Examples(examples=example_list, inputs=[text])

            # Middle column: results of the original model.
            original_model_outputs = list(_build_result_panel(
                '<h2 align="center">Original Stable Diffusion 1.4</h2>',
                "Original Model",
                servicer.get_sdm_params(servicer.pipe_original),
            ))

            # Right column: results of the compressed model.
            compressed_model_outputs = list(_build_result_panel(
                '<h2 align="center">Compressed Stable Diffusion (Ours)</h2>',
                "Compressed Model",
                servicer.get_sdm_params(servicer.pipe_compressed),
            ))

        inputs = [text, negative, guidance_scale, steps, seed]

        # Submitting the prompt textbox has a handler registered for each
        # model; each button triggers only its own model.
        text.submit(
            servicer.infer_original_model,
            inputs=inputs, outputs=original_model_outputs,
        )
        generate_original_button.click(
            servicer.infer_original_model,
            inputs=inputs, outputs=original_model_outputs,
        )

        text.submit(
            servicer.infer_compressed_model,
            inputs=inputs, outputs=compressed_model_outputs,
        )
        generate_compressed_button.click(
            servicer.infer_compressed_model,
            inputs=inputs, outputs=compressed_model_outputs,
        )

        gr.Markdown(_read_doc('footer.md'))

    return demo


if __name__ == "__main__":
    # Use the GPU when one is available; otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    servicer = SdmCompressionDemo(device)

    demo = build_demo(servicer)
    demo.queue()
    demo.launch()