# app.py (author: Sal-ONE; commit 798763c "Update app.py", verified)
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
import random
# --- Model loading -----------------------------------------------------------
def load_model(model_id):
    """Return a Stable Diffusion pipeline for ``model_id`` placed on the CPU.

    The weights are requested as float32 (``torch_dtype=torch.float32``) and
    the pipeline is moved to the CPU before it is returned.

    Args:
        model_id: Identifier passed straight to ``from_pretrained``.

    Returns:
        The loaded ``StableDiffusionPipeline``.
    """
    return StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
    ).to("cpu")
# Model ids offered in the "Select Model" dropdown (passed to from_pretrained).
model_options = [
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1",
]
# The first option is the model loaded at startup.
current_model_id = model_options[0]
# Load eagerly at import time so the UI starts with a ready pipeline.
pipeline = load_model(current_model_id)
def switch_model(selected_model):
    """Make ``selected_model`` the active Stable Diffusion pipeline.

    Args:
        selected_model: Model id chosen in the "Select Model" dropdown.

    Returns:
        A status message for the "Model Status" textbox.

    Raises:
        gr.Error: if the model cannot be loaded. The previously active
            pipeline and ``current_model_id`` are left untouched.
    """
    global pipeline, current_model_id
    # Nothing selected (e.g. the dropdown was cleared): keep the current model.
    if not selected_model:
        return f"Current model: {current_model_id}"
    # Loading a pipeline is slow and memory-hungry; skip it when the requested
    # model is already the active one.
    if selected_model == current_model_id:
        return f"Current model: {selected_model}"
    try:
        new_pipeline = load_model(selected_model)
    except Exception as exc:  # UI boundary: report the failure to the user
        raise gr.Error(f"Could not load model {selected_model}: {exc}") from exc
    # Swap only after a successful load so a failure never leaves us without
    # a usable pipeline, and keep the tracked id in sync with what's loaded.
    pipeline = new_pipeline
    current_model_id = selected_model
    return f"Model switched to: {selected_model}"
def generate_image(prompt, num_inference_steps=20, guidance_scale=7.5, seed=None):
    """Generate an image from a text prompt with the active pipeline.

    Args:
        prompt: Text prompt describing the desired image.
        num_inference_steps: Number of denoising steps. Coerced to ``int``
            because UI sliders may deliver a float.
        guidance_scale: Guidance strength forwarded to the pipeline.
        seed: Optional seed for reproducible output; ``None`` leaves the
            pipeline to draw its own randomness.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    generator = None
    if seed is not None:
        # A dedicated generator keeps the seed local to this request.
        # torch.manual_seed would reseed the process-wide RNG and leak into
        # other concurrent requests. int() accepts float seeds from gr.Number.
        generator = torch.Generator(device="cpu").manual_seed(int(seed))
    result = pipeline(
        prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    )
    return result.images[0]
# Define the Gradio app layout: controls on the left, generated image on the right.
with gr.Blocks() as app:
    gr.Markdown("# Stable Diffusion Gradio App\nGenerate stunning images from text prompts and switch between models!")
    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=model_options,
                value=current_model_id,
                interactive=True,
            )
            model_switch_status = gr.Textbox(
                label="Model Status",
                value=f"Current model: {current_model_id}",
                interactive=False,
            )
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=2)
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=30, value=20, step=1)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=7.5, step=0.5)
            # precision=0 makes Gradio round the seed and pass it as an int
            # rather than a float, so the seed the user sees is the one used.
            seed = gr.Number(label="Seed (Optional)", value=None, precision=0)
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    # Reload the pipeline whenever a different model is chosen.
    model_selector.change(
        switch_model,
        inputs=[model_selector],
        outputs=model_switch_status,
    )
    # Run text-to-image generation with the current control values.
    generate_btn.click(
        generate_image,
        inputs=[prompt, num_inference_steps, guidance_scale, seed],
        outputs=output_image,
    )
def _main():
    """Start the Gradio server for the app built above."""
    app.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()