from typing import Literal, Optional, TYPE_CHECKING

import numpy as np
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
from platform import system

if TYPE_CHECKING:
    from flux_pipeline import FluxPipeline

if system() == "Windows":
    MAX_RAND = 2**16 - 1
else:
    MAX_RAND = 2**32 - 1

class AppState:
    model: "FluxPipeline"


class FastAPIApp(FastAPI):
    state: AppState


class LoraArgs(BaseModel):
    scale: Optional[float] = 1.0
    path: Optional[str] = None
    name: Optional[str] = None
    action: Optional[Literal["load", "unload"]] = "load"


class LoraLoadResponse(BaseModel):
    status: Literal["success", "error"]
    message: Optional[str] = None


class GenerateArgs(BaseModel):
    prompt: str
    width: Optional[int] = Field(default=720)
    height: Optional[int] = Field(default=1024)
    num_steps: Optional[int] = Field(default=24)
    guidance: Optional[float] = Field(default=3.5)
    seed: Optional[int] = Field(
        default_factory=lambda: np.random.randint(0, MAX_RAND), gt=0, lt=MAX_RAND
    )
    strength: Optional[float] = 1.0
    init_image: Optional[str] = None


app = FastAPIApp()

@app.post("/generate")
def generate(args: GenerateArgs):
"""
Generates an image from the Flux flow transformer.
Args:
args (GenerateArgs): Arguments for image generation:
- `prompt`: The prompt used for image generation.
- `width`: The width of the image.
- `height`: The height of the image.
- `num_steps`: The number of steps for the image generation.
- `guidance`: The guidance for image generation, represents the
influence of the prompt on the image generation.
- `seed`: The seed for the image generation.
- `strength`: strength for image generation, 0.0 - 1.0.
Represents the percent of diffusion steps to run,
setting the init_image as the noised latent at the
given number of steps.
- `init_image`: Base64 encoded image or path to image to use as the init image.
Returns:
StreamingResponse: The generated image as streaming jpeg bytes.
"""
result = app.state.model.generate(**args.model_dump())
return StreamingResponse(result, media_type="image/jpeg")
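

# Example of calling /generate from a client (a sketch, not part of this module):
# it assumes the server is reachable at http://localhost:8088 (adjust host/port to
# however the app is actually served) and that the `requests` package is installed.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8088/generate",
#       json={"prompt": "a photo of a forest", "width": 768, "height": 1024},
#   )
#   with open("output.jpg", "wb") as f:
#       f.write(resp.content)  # the streamed JPEG bytes
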
@app.post("/lora", response_model=LoraLoadResponse)
def lora_action(args: LoraArgs):
"""
Loads or unloads a LoRA checkpoint into / from the Flux flow transformer.
Args:
args (LoraArgs): Arguments for the LoRA action:
- `scale`: The scaling factor for the LoRA weights.
- `path`: The path to the LoRA checkpoint.
- `name`: The name of the LoRA checkpoint.
- `action`: The action to perform, either "load" or "unload".
Returns:
LoraLoadResponse: The status of the LoRA action.
"""
try:
if args.action == "load":
app.state.model.load_lora(args.path, args.scale, args.name)
elif args.action == "unload":
app.state.model.unload_lora(args.name if args.name else args.path)
else:
return JSONResponse(
content={
"status": "error",
"message": f"Invalid action, expected 'load' or 'unload', got {args.action}",
},
status_code=400,
)
except Exception as e:
return JSONResponse(
status_code=500, content={"status": "error", "message": str(e)}
)
return JSONResponse(status_code=200, content={"status": "success"})
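

# Hypothetical client sketch for the /lora endpoint (not part of the server, only a
# quick way to exercise it by hand). It assumes the app is served at
# http://localhost:8088, that the `requests` package is installed, and that the
# LoRA path below is replaced with a real .safetensors checkpoint.
if __name__ == "__main__":
    import requests

    base_url = "http://localhost:8088"

    # Load a LoRA at 0.8 scale; `path` here is an illustrative placeholder.
    resp = requests.post(
        f"{base_url}/lora",
        json={"action": "load", "path": "/path/to/lora.safetensors", "scale": 0.8},
    )
    print(resp.status_code, resp.json())

    # Unload it again, identified by the same path (or by `name`, if one was given).
    resp = requests.post(
        f"{base_url}/lora",
        json={"action": "unload", "path": "/path/to/lora.safetensors"},
    )
    print(resp.status_code, resp.json())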