import spaces
from accelerate import dispatch_model
from fastapi import FastAPI, HTTPException, UploadFile, File
from typing import Optional, Dict, Any
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoPipelineForText2Image
)
import gradio as gr
from PIL import Image
import numpy as np
import gc
from io import BytesIO
import base64
import functools

app = FastAPI()

# Comprehensive model registry
MODELS = {
    "SDXL-Base": {
        "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
        "pipeline": StableDiffusionXLPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 100, "default": 50},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "SDXL-Turbo": {
        "model_id": "stabilityai/sdxl-turbo",
        "pipeline": AutoPipelineForText2Image,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 1},
            # SDXL-Turbo is distilled for guidance-free sampling (guidance_scale=0.0)
            "guidance_scale": {"min": 0.0, "max": 20.0, "default": 0.0},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "SD-1.5": {
        "model_id": "runwayml/stable-diffusion-v1-5",
        "pipeline": StableDiffusionPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 30},
            "guidance_scale": {"min": 1, "max": 20, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "Waifu-Diffusion": {
        "model_id": "hakurei/waifu-diffusion",
        "pipeline": StableDiffusionPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 100, "default": 50},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "Flux": {
        "model_id": "black-forest-labs/FLUX.1-dev",
        "pipeline": AutoPipelineForText2Image,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 25},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    }
}


class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_pipeline = None
        self.model_cache: Dict[str, Any] = {}
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._dtype = torch.float16 if self._device == "cuda" else torch.float32

    def _clear_memory(self):
        """Clear CUDA memory and garbage collect"""
        if self.current_pipeline is not None:
            del self.current_pipeline
            self.current_pipeline = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    # NOTE: the lru_cache keeps a reference to the returned pipeline until the
    # entry is evicted, so GPU memory from the previous model may only be
    # released after the next model has been loaded.
    @functools.lru_cache(maxsize=1)
    def get_model_config(self, model_id: str, pipeline_class):
        """Load and cache a pipeline for the given model"""
        return pipeline_class.from_pretrained(
            model_id,
            torch_dtype=self._dtype,
            # Not every repo ships an fp16 variant; drop `variant` if loading fails
            variant="fp16" if self._device == "cuda" else None,
            device_map="balanced"
        )

    def load_model(self, model_name: str):
        """Load model with memory optimization"""
        if self.current_model != model_name:
            self._clear_memory()
            try:
                model_info = MODELS[model_name]
                self.current_pipeline = self.get_model_config(
                    model_info["model_id"],
                    model_info["pipeline"]
                )
                if hasattr(self.current_pipeline, 'enable_xformers_memory_efficient_attention'):
                    self.current_pipeline.enable_xformers_memory_efficient_attention()
                # if self._device == "cuda":
                #     self.current_pipeline.enable_model_cpu_offload()
                self.current_model = model_name
            except Exception as e:
                self._clear_memory()
                raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e
        return self.current_pipeline

    def unload_current_model(self):
        """Explicitly unload current model"""
        self._clear_memory()
        self.current_model = None

    def get_memory_status(self):
        """Get current memory usage status (values in GB)"""
        if not torch.cuda.is_available():
            return {"status": "CPU Mode"}
        return {
            "total": torch.cuda.get_device_properties(0).total_memory / 1e9,
            "allocated": torch.cuda.memory_allocated() / 1e9,
            "cached": torch.cuda.memory_reserved() / 1e9,
            "free": (torch.cuda.get_device_properties(0).total_memory -
                     torch.cuda.memory_allocated()) / 1e9
        }


class ModelContext:
    def __init__(self, model_name: str):
        self.model_name = model_name

    def __enter__(self):
        pipeline = model_manager.load_model(self.model_name)
        if hasattr(pipeline, 'reset_device_map'):
            pipeline.reset_device_map()
        # Check if the pipeline supports dispatch_model
        # (diffusers pipelines do not expose state_dict, so this branch is normally skipped)
        if hasattr(pipeline, 'state_dict'):
            dispatch_model(pipeline, device_map="auto")
        return pipeline

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            model_manager.unload_current_model()


model_manager = ModelManager()
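
# Illustrative usage sketch (not exercised by the app itself): ModelContext can
# also be driven directly from a Python session. The model name and prompt
# below are hypothetical examples.
#
#     with ModelContext("SD-1.5") as pipe:
#         image = pipe("a watercolor fox", num_inference_steps=30).images[0]
#     image.save("fox.png")
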

@spaces.GPU
def generate_image(
    model_name: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    reference_image: Optional[Image.Image] = None
) -> dict:
    try:
        with ModelContext(model_name) as pipeline:
            pre_mem = model_manager.get_memory_status()

            # Process reference image if provided
            # NOTE: the registered pipelines are text-to-image; a dedicated
            # img2img pipeline would be needed for the reference image to
            # actually condition the output.
            if reference_image is not None and MODELS[model_name]["supports_img2img"]:
                reference_image = reference_image.resize((width, height))

            # Generate image
            generation_params = {
                "prompt": prompt,
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps or MODELS[model_name]["parameters"]["num_inference_steps"]["default"],
                "guidance_scale": guidance_scale or MODELS[model_name]["parameters"]["guidance_scale"]["default"]
            }

            if reference_image is not None:
                generation_params["image"] = reference_image

            image = pipeline(**generation_params).images[0]

            # Convert to base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

            post_mem = model_manager.get_memory_status()

            return {
                "status": "success",
                "image_base64": img_str,
                "memory": {
                    "before": pre_mem,
                    "after": post_mem
                }
            }
    except Exception as e:
        model_manager.unload_current_model()
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/generate")
async def generate_image_endpoint(
    model_name: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    reference_image: UploadFile = File(None)
):
    ref_img = None
    if reference_image:
        content = await reference_image.read()
        ref_img = Image.open(BytesIO(content))

    return generate_image(
        model_name=model_name,
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        reference_image=ref_img
    )


@app.get("/memory")
async def get_memory_status():
    return model_manager.get_memory_status()


@app.post("/unload")
async def unload_model():
    model_manager.unload_current_model()
    return {"status": "success", "message": "Model unloaded"}
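
# Example client call against the /generate endpoint (illustrative sketch,
# assuming the API is reachable at http://localhost:8000 and the `requests`
# package is installed; model name and prompt are placeholders):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/generate",
#         params={"model_name": "SD-1.5", "prompt": "a lighthouse at dusk"},
#     )
#     image_b64 = resp.json()["image_base64"]
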

def create_gradio_interface() -> gr.Blocks:
    with gr.Blocks() as interface:
        gr.Markdown("# Text-to-Image Generation Interface")

        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Select Model"
                )
                prompt = gr.Textbox(
                    lines=3,
                    label="Prompt",
                    placeholder="Enter your image description here..."
                )
                with gr.Row():
                    height = gr.Slider(
                        minimum=256, maximum=1024, value=512, step=64,
                        label="Height"
                    )
                    width = gr.Slider(
                        minimum=256, maximum=1024, value=512, step=64,
                        label="Width"
                    )
                with gr.Row():
                    num_steps = gr.Slider(
                        minimum=1, maximum=100, value=50, step=1,
                        label="Number of Inference Steps"
                    )
                    guidance = gr.Slider(
                        minimum=1, maximum=15, value=7.5, step=0.1,
                        label="Guidance Scale"
                    )
                reference_image = gr.Image(
                    type="pil",
                    label="Reference Image (optional)"
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    unload_btn = gr.Button("Unload Model")

            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")
                memory_status = gr.JSON(
                    label="Memory Status",
                    value=model_manager.get_memory_status()
                )

        def update_params(model_name: str) -> list:
            """Adjust slider ranges and defaults for the selected model"""
            model_config = MODELS[model_name]["parameters"]
            return [
                gr.update(
                    minimum=model_config["height"]["min"],
                    maximum=model_config["height"]["max"],
                    value=model_config["height"]["default"],
                    step=model_config["height"]["step"]
                ),
                gr.update(
                    minimum=model_config["width"]["min"],
                    maximum=model_config["width"]["max"],
                    value=model_config["width"]["default"],
                    step=model_config["width"]["step"]
                ),
                gr.update(
                    minimum=model_config["num_inference_steps"]["min"],
                    maximum=model_config["num_inference_steps"]["max"],
                    value=model_config["num_inference_steps"]["default"]
                ),
                gr.update(
                    minimum=model_config["guidance_scale"]["min"],
                    maximum=model_config["guidance_scale"]["max"],
                    value=model_config["guidance_scale"]["default"]
                )
            ]

        def generate(model_name: str, prompt_text: str, h: int, w: int,
                     steps: int, guide_scale: float,
                     ref_img: Optional[Image.Image]) -> Image.Image:
            response = generate_image(
                model_name=model_name,
                prompt=prompt_text,
                height=h,
                width=w,
                num_inference_steps=steps,
                guidance_scale=guide_scale,
                reference_image=ref_img
            )
            return Image.open(BytesIO(base64.b64decode(response["image_base64"])))

        def unload() -> dict:
            """Unload the current model and return the updated memory status"""
            model_manager.unload_current_model()
            return model_manager.get_memory_status()

        model_dropdown.change(
            update_params,
            inputs=[model_dropdown],
            outputs=[height, width, num_steps, guidance]
        )

        generate_btn.click(
            generate,
            inputs=[
                model_dropdown, prompt, height, width,
                num_steps, guidance, reference_image
            ],
            outputs=[output_image]
        )

        unload_btn.click(
            unload,
            outputs=[memory_status]
        )

    return interface


if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    # Launch the Gradio interface in a background thread
    interface = create_gradio_interface()
    gradio_thread = Thread(
        target=interface.launch,
        kwargs={
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": False
        }
    )
    gradio_thread.start()

    # Launch FastAPI
    uvicorn.run(app, host="0.0.0.0", port=8000)
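
# Alternative serving layout (sketch only, not what this script does): the
# Gradio Blocks app could instead be mounted onto the FastAPI app and both
# served by a single uvicorn process; the "/ui" path is an arbitrary choice.
#
#     app = gr.mount_gradio_app(app, create_gradio_interface(), path="/ui")
#     uvicorn.run(app, host="0.0.0.0", port=8000)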