import spaces
from accelerate import dispatch_model
from fastapi import FastAPI, HTTPException, UploadFile, File
from typing import Optional, Dict, Any
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoPipelineForText2Image
)
import gradio as gr
from PIL import Image
import numpy as np
import gc
from io import BytesIO
import base64
import functools

app = FastAPI()
# Comprehensive model registry
MODELS = {
    "SDXL-Base": {
        "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
        "pipeline": StableDiffusionXLPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 100, "default": 50},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "SDXL-Turbo": {
        "model_id": "stabilityai/sdxl-turbo",
        "pipeline": AutoPipelineForText2Image,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 1},
            "guidance_scale": {"min": 0.0, "max": 20.0, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "SD-1.5": {
        "model_id": "runwayml/stable-diffusion-v1-5",
        "pipeline": StableDiffusionPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 30},
            "guidance_scale": {"min": 1, "max": 20, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "Waifu-Diffusion": {
        "model_id": "hakurei/waifu-diffusion",
        "pipeline": StableDiffusionPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 100, "default": 50},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
"Flux": { | |
"model_id": "black-forest-labs/flux-1-1-dev", | |
"pipeline": AutoPipelineForText2Image, | |
"supports_img2img": True, | |
"parameters": { | |
"num_inference_steps": {"min": 1, "max": 50, "default": 25}, | |
"guidance_scale": {"min": 1, "max": 15, "default": 7.5}, | |
"width": {"min": 256, "max": 1024, "default": 512, "step": 64}, | |
"height": {"min": 256, "max": 1024, "default": 512, "step": 64} | |
} | |
} | |
} | |
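
# Each registry entry pairs a Hub repo id with the diffusers pipeline class used to
# load it, an img2img capability flag, and the parameter ranges the UI exposes for
# that model; adding a model means appending another entry of the same shape.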
class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_pipeline = None
        self.model_cache: Dict[str, Any] = {}
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._dtype = torch.float16 if self._device == "cuda" else torch.float32

    def _clear_memory(self):
        """Clear CUDA memory and garbage collect"""
        if self.current_pipeline is not None:
            del self.current_pipeline
            self.current_pipeline = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    def get_model_config(self, model_id: str, pipeline_class):
        """Load and cache model configuration"""
        return pipeline_class.from_pretrained(
            model_id,
            torch_dtype=self._dtype,
            variant="fp16" if self._device == "cuda" else None,
            device_map="balanced"
        )

    def load_model(self, model_name: str):
        """Load model with memory optimization"""
        if self.current_model != model_name:
            self._clear_memory()
            try:
                model_info = MODELS[model_name]
                self.current_pipeline = self.get_model_config(
                    model_info["model_id"],
                    model_info["pipeline"]
                )
                if hasattr(self.current_pipeline, 'enable_xformers_memory_efficient_attention'):
                    self.current_pipeline.enable_xformers_memory_efficient_attention()
                # if self._device == "cuda":
                #     self.current_pipeline.enable_model_cpu_offload()
                self.current_model = model_name
            except Exception as e:
                self._clear_memory()
                raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
        return self.current_pipeline

    def unload_current_model(self):
        """Explicitly unload current model"""
        self._clear_memory()
        self.current_model = None

    def get_memory_status(self):
        """Get current memory usage status"""
        if not torch.cuda.is_available():
            return {"status": "CPU Mode"}
        return {
            "total": torch.cuda.get_device_properties(0).total_memory / 1e9,
            "allocated": torch.cuda.memory_allocated() / 1e9,
            "cached": torch.cuda.memory_reserved() / 1e9,
            "free": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / 1e9
        }
class ModelContext:
    def __init__(self, model_name: str):
        self.model_name = model_name

    def __enter__(self):
        pipeline = model_manager.load_model(self.model_name)
        if hasattr(pipeline, 'reset_device_map'):
            pipeline.reset_device_map()
        # Check if the pipeline supports dispatch_model
        if hasattr(pipeline, 'state_dict'):
            dispatch_model(pipeline, device_map="auto")
        return pipeline

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            model_manager.unload_current_model()


model_manager = ModelManager()
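
# Minimal usage sketch (illustrative, not part of the app's API): loading through
# ModelContext returns an ordinary diffusers pipeline and unloads it if the block fails.
# The prompt and step count below are made up for the example.
#
#   with ModelContext("SD-1.5") as pipe:
#       image = pipe("a lighthouse at dusk", num_inference_steps=30).images[0]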
def generate_image(
    model_name: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    reference_image: Optional[Image.Image] = None
) -> dict:
    try:
        with ModelContext(model_name) as pipeline:
            pre_mem = model_manager.get_memory_status()
            # Process reference image if provided and the model supports img2img
            if reference_image is not None and MODELS[model_name]["supports_img2img"]:
                reference_image = reference_image.resize((width, height))
            # Generate image ("is not None" keeps an explicit 0 from falling back to the default)
            generation_params = {
                "prompt": prompt,
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps if num_inference_steps is not None else MODELS[model_name]["parameters"]["num_inference_steps"]["default"],
                "guidance_scale": guidance_scale if guidance_scale is not None else MODELS[model_name]["parameters"]["guidance_scale"]["default"]
            }
            if reference_image is not None and MODELS[model_name]["supports_img2img"]:
                generation_params["image"] = reference_image
            image = pipeline(**generation_params).images[0]
            # Convert to base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            post_mem = model_manager.get_memory_status()
            return {
                "status": "success",
                "image_base64": img_str,
                "memory": {
                    "before": pre_mem,
                    "after": post_mem
                }
            }
    except Exception as e:
        model_manager.unload_current_model()
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate")  # route path is illustrative; adjust to your deployment
async def generate_image_endpoint(
    model_name: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    reference_image: Optional[UploadFile] = File(None)
):
    ref_img = None
    if reference_image:
        content = await reference_image.read()
        ref_img = Image.open(BytesIO(content))
    return generate_image(
        model_name=model_name,
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        reference_image=ref_img
    )
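
# Example client call (sketch): the "/generate" path and port 8000 match the route and
# uvicorn settings in this file; the prompt and the file name "ref.png" are hypothetical.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       params={"model_name": "SD-1.5", "prompt": "a watercolor fox"},
#       files={"reference_image": open("ref.png", "rb")},
#   )
#   print(resp.json()["status"])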
@app.get("/memory")
async def get_memory_status():
    return model_manager.get_memory_status()


@app.post("/unload")
async def unload_model():
    model_manager.unload_current_model()
    return {"status": "success", "message": "Model unloaded"}
def create_gradio_interface() -> gr.Blocks:
    with gr.Blocks() as interface:
        gr.Markdown("# Text-to-Image Generation Interface")
        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Select Model"
                )
                prompt = gr.Textbox(
                    lines=3,
                    label="Prompt",
                    placeholder="Enter your image description here..."
                )
                with gr.Row():
                    height = gr.Slider(
                        minimum=256,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Height"
                    )
                    width = gr.Slider(
                        minimum=256,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Width"
                    )
                with gr.Row():
                    num_steps = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Number of Inference Steps"
                    )
                    guidance = gr.Slider(
                        minimum=1,
                        maximum=15,
                        value=7.5,
                        step=0.1,
                        label="Guidance Scale"
                    )
                reference_image = gr.Image(
                    type="pil",
                    label="Reference Image (optional)"
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    unload_btn = gr.Button("Unload Model")
            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")
                memory_status = gr.JSON(
                    label="Memory Status",
                    value=model_manager.get_memory_status()
                )
        def update_params(model_name: str) -> list:
            model_config = MODELS[model_name]["parameters"]
            return [
                gr.update(
                    minimum=model_config["height"]["min"],
                    maximum=model_config["height"]["max"],
                    value=model_config["height"]["default"],
                    step=model_config["height"]["step"]
                ),
                gr.update(
                    minimum=model_config["width"]["min"],
                    maximum=model_config["width"]["max"],
                    value=model_config["width"]["default"],
                    step=model_config["width"]["step"]
                ),
                gr.update(
                    minimum=model_config["num_inference_steps"]["min"],
                    maximum=model_config["num_inference_steps"]["max"],
                    value=model_config["num_inference_steps"]["default"]
                ),
                gr.update(
                    minimum=model_config["guidance_scale"]["min"],
                    maximum=model_config["guidance_scale"]["max"],
                    value=model_config["guidance_scale"]["default"]
                )
            ]

        def generate(model_name: str, prompt_text: str, h: int, w: int, steps: int, guide_scale: float, ref_img: Optional[Image.Image]) -> Image.Image:
            response = generate_image(
                model_name=model_name,
                prompt=prompt_text,
                height=h,
                width=w,
                num_inference_steps=steps,
                guidance_scale=guide_scale,
                reference_image=ref_img
            )
            return Image.open(BytesIO(base64.b64decode(response["image_base64"])))
        model_dropdown.change(
            update_params,
            inputs=[model_dropdown],
            outputs=[height, width, num_steps, guidance]
        )
        generate_btn.click(
            generate,
            inputs=[
                model_dropdown,
                prompt,
                height,
                width,
                num_steps,
                guidance,
                reference_image
            ],
            outputs=[output_image]
        )
        unload_btn.click(
            # unload_current_model returns None, so index [1] hands only the memory report to the JSON panel
            lambda: (model_manager.unload_current_model(), model_manager.get_memory_status())[1],
            outputs=[memory_status]
        )
    return interface
if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    # Launch Gradio interface
    interface = create_gradio_interface()
    gradio_thread = Thread(
        target=interface.launch,
        kwargs={
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": False
        }
    )
    gradio_thread.start()
    # Launch FastAPI
    uvicorn.run(app, host="0.0.0.0", port=8000)
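
# Note: when this file is hosted as a Hugging Face Space, only the Gradio port (7860)
# is typically reachable from outside, so the FastAPI server on port 8000 is mainly
# useful for requests made from within the same container.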