# Hugging Face Spaces app (status when captured: Running)
import gradio as gr | |
from diffusers import StableDiffusionPipeline | |
import torch | |
import random | |
# Function to load the selected Stable Diffusion model | |
def load_model(model_id):
    """Build a Stable Diffusion pipeline for ``model_id`` and place it on the CPU.

    Args:
        model_id: Identifier handed to
            ``StableDiffusionPipeline.from_pretrained``.

    Returns:
        The pipeline, loaded with ``torch.float32`` weights and moved to
        the ``"cpu"`` device.
    """
    loaded = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
    )
    return loaded.to("cpu")
# Initial model | |
# Model identifiers the app can load; the first entry is the startup default.
model_options = [
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1",
]
current_model_id = model_options[0]

# Loaded eagerly at import time so the module-level ``pipeline`` exists
# before any handler that reads it can run.
pipeline = load_model(current_model_id)
def switch_model(selected_model):
    """Swap the active pipeline for the one identified by ``selected_model``.

    Rebinds the module-level ``pipeline`` to a freshly loaded model.

    Args:
        selected_model: Model identifier forwarded to ``load_model``.

    Returns:
        A status message naming the model that is now active.
    """
    global pipeline
    replacement = load_model(selected_model)
    pipeline = replacement
    status = f"Model switched to: {selected_model}"
    return status
def generate_image(prompt, num_inference_steps=20, guidance_scale=7.5, seed=None):
    """Generate an image from a text prompt using the active pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of inference steps. Coerced to ``int``
            because UI widgets may deliver numeric values as floats.
        guidance_scale: Guidance scale forwarded to the pipeline unchanged.
        seed: Optional seed. ``None`` leaves generation unseeded; any other
            value is truncated to an integer and used to seed a private
            generator.

    Returns:
        The first entry of the pipeline result's ``images`` list.
    """
    generator = None
    if seed is not None:
        # A dedicated generator keeps a seeded request from reseeding torch's
        # process-wide RNG, which would otherwise alter the random stream used
        # by later requests that did not ask for a seed.
        generator = torch.Generator(device="cpu").manual_seed(int(seed))

    result = pipeline(
        prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    )
    return result.images[0]
# Define the Gradio app layout | |
with gr.Blocks() as app:
    gr.Markdown("# Stable Diffusion Gradio App\nGenerate stunning images from text prompts and switch between models!")

    with gr.Row():
        # First column: model selection and generation controls.
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=model_options,
                value=current_model_id,
                interactive=True,
            )
            model_switch_status = gr.Textbox(
                label="Model Status",
                value=f"Current model: {current_model_id}",
                interactive=False,
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here",
                lines=2,
            )
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=10,
                maximum=30,
                value=20,
                step=1,
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1.0,
                maximum=20.0,
                value=7.5,
                step=0.5,
            )
            seed = gr.Number(label="Seed (Optional)", value=None)
            generate_btn = gr.Button("Generate Image")

        # Second column: the generated image.
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    # Changing the dropdown reloads the pipeline and reports the result
    # in the status box.
    model_selector.change(
        fn=switch_model,
        inputs=[model_selector],
        outputs=model_switch_status,
    )

    # The button feeds the current control values to the generator and
    # shows the returned image.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, num_inference_steps, guidance_scale, seed],
        outputs=output_image,
    )

# Run the app
if __name__ == "__main__":
    app.launch()