import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
import random

def load_model(model_id, device="cpu", torch_dtype=torch.float32):
    """Load a Stable Diffusion pipeline and move it to a device.

    Args:
        model_id: Model identifier passed to
            ``StableDiffusionPipeline.from_pretrained`` (e.g. a Hub repo id).
        device: Device the pipeline is moved to. Defaults to ``"cpu"``,
            the value this app previously hard-coded.
        torch_dtype: dtype the weights are loaded in. Defaults to
            ``torch.float32``, the value this app previously hard-coded.

    Returns:
        The loaded ``StableDiffusionPipeline``, placed on ``device``.
    """
    # Named ``pipe`` so it does not shadow the module-level ``pipeline`` global.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    return pipe.to(device)

# Model ids offered in the UI dropdown; the first entry is loaded at startup.
model_options = [
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1",
]
current_model_id = model_options[0]
# Loaded eagerly at import time; generate_image() reads this global.
pipeline = load_model(model_id=current_model_id)

def switch_model(selected_model):
    """Make ``selected_model`` the active global Stable Diffusion pipeline.

    Args:
        selected_model: Model id chosen in the model dropdown.

    Returns:
        A status message for the "Model Status" textbox.
    """
    global pipeline, current_model_id

    if not selected_model:
        # Nothing selected (e.g. an empty dropdown value): keep the current model
        # instead of handing None/"" to from_pretrained.
        return f"Current model: {current_model_id}"

    if selected_model == current_model_id:
        # Already the active model; a reload would be redundant work.
        return f"Model switched to: {selected_model}"

    # Load first and only then rebind the globals, so an exception during the
    # load leaves the previously working pipeline and its id untouched.
    new_pipeline = load_model(selected_model)
    pipeline = new_pipeline
    current_model_id = selected_model
    return f"Model switched to: {selected_model}"

def generate_image(prompt, num_inference_steps=20, guidance_scale=7.5, seed=None):
    """Generate one image from a text prompt using the active pipeline.

    Args:
        prompt: Text prompt for the pipeline.
        num_inference_steps: Number of denoising steps; coerced to ``int`` in
            case the UI delivers it as a float.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Optional seed for reproducible output. ``None`` leaves generation
            unseeded. Coerced to ``int`` in case the UI delivers it as a float.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    generator = None
    if seed is not None:
        # A private generator instead of torch.manual_seed(): the latter reseeds
        # the process-wide default RNG, which would make every later "unseeded"
        # request deterministic. The same seed still yields the same sequence.
        generator = torch.Generator(device="cpu").manual_seed(int(seed))

    result = pipeline(
        prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    )
    return result.images[0]

# Gradio app layout: controls in the left column, generated image in the right.
with gr.Blocks() as app:
    gr.Markdown("# Stable Diffusion Gradio App\nGenerate stunning images from text prompts and switch between models!")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=model_options,
                value=current_model_id,
                interactive=True,
            )
            model_switch_status = gr.Textbox(label="Model Status", value=f"Current model: {current_model_id}", interactive=False)
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=2)
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=30, value=20, step=1)
            guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=7.5, step=0.5)
            # precision=0 rounds the entry to an integer and returns it as int,
            # since a seed must be a whole number.
            seed = gr.Number(label="Seed (Optional)", value=None, precision=0)
            generate_btn = gr.Button("Generate Image")

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    # Reload the pipeline when a different model is picked; report the result.
    model_selector.change(
        switch_model,
        inputs=[model_selector],
        outputs=model_switch_status,
    )

    # Inputs are passed positionally, in generate_image's parameter order.
    generate_btn.click(
        generate_image,
        inputs=[prompt, num_inference_steps, guidance_scale, seed],
        outputs=output_image,
    )

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()