import spaces
from accelerate import dispatch_model
from fastapi import FastAPI, HTTPException, UploadFile, File
from typing import Optional, Dict, Any
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoPipelineForText2Image
)
import gradio as gr
from PIL import Image
import numpy as np
import gc
from io import BytesIO
import base64
import functools
app = FastAPI()
# Comprehensive model registry
MODELS = {
    "SDXL-Base": {
        "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
        "pipeline": StableDiffusionXLPipeline,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 100, "default": 50},
            "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
            "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
            "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
        }
    },
    "SDXL-Turbo": {
        "model_id": "stabilityai/sdxl-turbo",
        "pipeline": AutoPipelineForText2Image,
        "supports_img2img": True,
        "parameters": {
            "num_inference_steps": {"min": 1, "max": 50, "default": 1},
"guidance_scale": {"min": 0.0, "max": 20.0, "default": 7.5},
"width": {"min": 256, "max": 1024, "default": 512, "step": 64},
"height": {"min": 256, "max": 1024, "default": 512, "step": 64}
}
},
"SD-1.5": {
"model_id": "runwayml/stable-diffusion-v1-5",
"pipeline": StableDiffusionPipeline,
"supports_img2img": True,
"parameters": {
"num_inference_steps": {"min": 1, "max": 50, "default": 30},
"guidance_scale": {"min": 1, "max": 20, "default": 7.5},
"width": {"min": 256, "max": 1024, "default": 512, "step": 64},
"height": {"min": 256, "max": 1024, "default": 512, "step": 64}
}
},
"Waifu-Diffusion": {
"model_id": "hakurei/waifu-diffusion",
"pipeline": StableDiffusionPipeline,
"supports_img2img": True,
"parameters": {
"num_inference_steps": {"min": 1, "max": 100, "default": 50},
"guidance_scale": {"min": 1, "max": 15, "default": 7.5},
"width": {"min": 256, "max": 1024, "default": 512, "step": 64},
"height": {"min": 256, "max": 1024, "default": 512, "step": 64}
}
},
"Flux": {
"model_id": "black-forest-labs/flux-1-1-dev",
"pipeline": AutoPipelineForText2Image,
"supports_img2img": True,
"parameters": {
"num_inference_steps": {"min": 1, "max": 50, "default": 25},
"guidance_scale": {"min": 1, "max": 15, "default": 7.5},
"width": {"min": 256, "max": 1024, "default": 512, "step": 64},
"height": {"min": 256, "max": 1024, "default": 512, "step": 64}
}
}
}
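# Each registry entry maps a display name to a Hugging Face repo id, the pipeline class
# used to load it, and the parameter ranges/defaults used for the Gradio sliders and as
# fallbacks in generate_image().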
class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_pipeline = None
        self.model_cache: Dict[str, Any] = {}
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._dtype = torch.float16 if self._device == "cuda" else torch.float32
    def _clear_memory(self):
        """Clear CUDA memory and garbage collect"""
        if self.current_pipeline is not None:
            del self.current_pipeline
            self.current_pipeline = None
        # Also drop the lru_cache entry below; otherwise the cache keeps a reference to
        # the pipeline and the GPU memory is never actually released.
        self.get_model_config.cache_clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()
    @functools.lru_cache(maxsize=1)
    def get_model_config(self, model_id: str, pipeline_class):
        """Load a pipeline from the Hub; lru_cache keeps only the most recently used one"""
        return pipeline_class.from_pretrained(
            model_id,
            torch_dtype=self._dtype,
            variant="fp16" if self._device == "cuda" else None,
            device_map="balanced"
        )
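    # Note on the cache above: with maxsize=1, loading a different model evicts the
    # previously cached pipeline, so at most one pipeline stays resident via this cache.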
    def load_model(self, model_name: str):
        """Load model with memory optimization"""
        if self.current_model != model_name:
            self._clear_memory()
            try:
                model_info = MODELS[model_name]
                self.current_pipeline = self.get_model_config(
                    model_info["model_id"],
                    model_info["pipeline"]
                )
                if hasattr(self.current_pipeline, 'enable_xformers_memory_efficient_attention'):
                    self.current_pipeline.enable_xformers_memory_efficient_attention()
                # Left disabled: model CPU offload does not combine cleanly with device_map dispatch
                # if self._device == "cuda":
                #     self.current_pipeline.enable_model_cpu_offload()
                self.current_model = model_name
            except Exception as e:
                self._clear_memory()
                raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e
        return self.current_pipeline
    def unload_current_model(self):
        """Explicitly unload current model"""
        self._clear_memory()
        self.current_model = None
    def get_memory_status(self):
        """Get current memory usage status (values in GB)"""
        if not torch.cuda.is_available():
            return {"status": "CPU Mode"}
        return {
            "total": torch.cuda.get_device_properties(0).total_memory / 1e9,
            "allocated": torch.cuda.memory_allocated() / 1e9,
            "cached": torch.cuda.memory_reserved() / 1e9,
            "free": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / 1e9
        }
class ModelContext:
    def __init__(self, model_name: str):
        self.model_name = model_name
    def __enter__(self):
        pipeline = model_manager.load_model(self.model_name)
        if hasattr(pipeline, 'reset_device_map'):
            pipeline.reset_device_map()
        # dispatch_model expects a torch.nn.Module; diffusers pipelines do not expose a
        # state_dict, so this dispatch branch only applies when a raw module is returned.
        if hasattr(pipeline, 'state_dict'):
            dispatch_model(pipeline, device_map="auto")
        return pipeline
    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            model_manager.unload_current_model()
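# Illustrative use of the context manager (any key from MODELS works as the model name):
#   with ModelContext("SD-1.5") as pipe:
#       image = pipe("a watercolor fox").images[0]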
model_manager = ModelManager()
@spaces.GPU
def generate_image(
    model_name: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    reference_image: Optional[Image.Image] = None
) -> dict:
    try:
        with ModelContext(model_name) as pipeline:
            pre_mem = model_manager.get_memory_status()
            # Process reference image if provided and the model supports img2img
            if reference_image is not None and MODELS[model_name]["supports_img2img"]:
                reference_image = reference_image.resize((width, height))
            # Generate image, falling back to per-model defaults for unset parameters
            generation_params = {
                "prompt": prompt,
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps if num_inference_steps is not None else MODELS[model_name]["parameters"]["num_inference_steps"]["default"],
                "guidance_scale": guidance_scale if guidance_scale is not None else MODELS[model_name]["parameters"]["guidance_scale"]["default"]
            }
            if reference_image is not None and MODELS[model_name]["supports_img2img"]:
                generation_params["image"] = reference_image
            image = pipeline(**generation_params).images[0]
            # Convert to base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            post_mem = model_manager.get_memory_status()
            return {
                "status": "success",
                "image_base64": img_str,
                "memory": {
                    "before": pre_mem,
                    "after": post_mem
                }
            }
    except Exception as e:
        model_manager.unload_current_model()
        raise HTTPException(status_code=500, detail=str(e))
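# Illustrative direct call to generate_image above (output filename is just an example):
#   result = generate_image("SD-1.5", "a lighthouse at dusk")
#   Image.open(BytesIO(base64.b64decode(result["image_base64"]))).save("out.png")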
@app.post("/generate")
async def generate_image_endpoint(
model_name: str,
prompt: str,
height: int = 512,
width: int = 512,
num_inference_steps: Optional[int] = None,
guidance_scale: Optional[float] = None,
reference_image: UploadFile = File(None)
):
ref_img = None
if reference_image:
content = await reference_image.read()
ref_img = Image.open(BytesIO(content))
return generate_image(
model_name=model_name,
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
reference_image=ref_img
)
@app.get("/memory")
async def get_memory_status():
return model_manager.get_memory_status()
@app.post("/unload")
async def unload_model():
model_manager.unload_current_model()
return {"status": "success", "message": "Model unloaded"}
def create_gradio_interface() -> gr.Blocks:
    with gr.Blocks() as interface:
        gr.Markdown("# Text-to-Image Generation Interface")
        with gr.Row():
            with gr.Column(scale=2):
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Select Model"
                )
                prompt = gr.Textbox(
                    lines=3,
                    label="Prompt",
                    placeholder="Enter your image description here..."
                )
                with gr.Row():
                    height = gr.Slider(
                        minimum=256,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Height"
                    )
                    width = gr.Slider(
                        minimum=256,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Width"
                    )
                with gr.Row():
                    num_steps = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Number of Inference Steps"
                    )
                    guidance = gr.Slider(
                        minimum=1,
                        maximum=15,
                        value=7.5,
                        step=0.1,
                        label="Guidance Scale"
                    )
                reference_image = gr.Image(
                    type="pil",
                    label="Reference Image (optional)"
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    unload_btn = gr.Button("Unload Model")
            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")
                memory_status = gr.JSON(
                    label="Memory Status",
                    value=model_manager.get_memory_status()
                )
        def update_params(model_name: str) -> list:
            model_config = MODELS[model_name]["parameters"]
            return [
                gr.update(
                    minimum=model_config["height"]["min"],
                    maximum=model_config["height"]["max"],
                    value=model_config["height"]["default"],
                    step=model_config["height"]["step"]
                ),
                gr.update(
                    minimum=model_config["width"]["min"],
                    maximum=model_config["width"]["max"],
                    value=model_config["width"]["default"],
                    step=model_config["width"]["step"]
                ),
                gr.update(
                    minimum=model_config["num_inference_steps"]["min"],
                    maximum=model_config["num_inference_steps"]["max"],
                    value=model_config["num_inference_steps"]["default"]
                ),
                gr.update(
                    minimum=model_config["guidance_scale"]["min"],
                    maximum=model_config["guidance_scale"]["max"],
                    value=model_config["guidance_scale"]["default"]
                )
            ]
        def generate(model_name: str, prompt_text: str, h: int, w: int, steps: int, guide_scale: float, ref_img: Optional[Image.Image]) -> Image.Image:
            response = generate_image(
                model_name=model_name,
                prompt=prompt_text,
                height=h,
                width=w,
                num_inference_steps=steps,
                guidance_scale=guide_scale,
                reference_image=ref_img
            )
            return Image.open(BytesIO(base64.b64decode(response["image_base64"])))
        model_dropdown.change(
            update_params,
            inputs=[model_dropdown],
            outputs=[height, width, num_steps, guidance]
        )
        generate_btn.click(
            generate,
            inputs=[
                model_dropdown,
                prompt,
                height,
                width,
                num_steps,
                guidance,
                reference_image
            ],
            outputs=[output_image]
        )
        def unload_and_report() -> dict:
            # Unload the current model, then show the updated memory status in the JSON panel
            model_manager.unload_current_model()
            return model_manager.get_memory_status()
        unload_btn.click(
            unload_and_report,
            outputs=[memory_status]
        )
    return interface
if __name__ == "__main__":
    import uvicorn
    from threading import Thread
    # Launch Gradio interface
    interface = create_gradio_interface()
    gradio_thread = Thread(
        target=interface.launch,
        kwargs={
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "share": False
        }
    )
    gradio_thread.start()
    # Launch FastAPI
    uvicorn.run(app, host="0.0.0.0", port=8000)
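# Running this file directly serves the Gradio UI on port 7860 and the FastAPI app on
# port 8000, per the settings above; on Hugging Face Spaces, the @spaces.GPU decorator
# requests GPU time for each generate_image call.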