"""HTTP API exposing image generation and LoRA management for a Flux pipeline."""

from platform import system
from typing import TYPE_CHECKING, Literal, Optional

import numpy as np
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

if TYPE_CHECKING:
    from flux_pipeline import FluxPipeline

# Exclusive upper bound for randomly drawn seeds. A smaller bound is used on
# Windows -- presumably because NumPy's default integer type is narrower
# there (TODO confirm).
if system() == "Windows":
    MAX_RAND = 2**16 - 1
else:
    MAX_RAND = 2**32 - 1


class AppState:
    """Typing shim describing the attributes expected on ``app.state``."""

    # The pipeline used by the endpoints below. It is not assigned anywhere in
    # this module, so it has to be set on ``app.state`` before serving.
    model: "FluxPipeline"


class FastAPIApp(FastAPI):
    """``FastAPI`` subclass whose ``state`` attribute is typed as ``AppState``."""

    state: AppState


class LoraArgs(BaseModel):
    """Request body for the ``/lora`` endpoint."""

    scale: Optional[float] = 1.0
    path: Optional[str] = None
    name: Optional[str] = None
    action: Optional[Literal["load", "unload"]] = "load"


class LoraLoadResponse(BaseModel):
    """Response schema advertised for the ``/lora`` endpoint."""

    status: Literal["success", "error"]
    message: Optional[str] = None


class GenerateArgs(BaseModel):
    """Request body for the ``/generate`` endpoint."""

    prompt: str
    width: Optional[int] = Field(default=720)
    height: Optional[int] = Field(default=1024)
    num_steps: Optional[int] = Field(default=24)
    guidance: Optional[float] = Field(default=3.5)
    seed: Optional[int] = Field(
        # Draw from [1, MAX_RAND) so the generated default always satisfies
        # the field's own ``gt=0`` / ``lt=MAX_RAND`` constraints (a lower
        # bound of 0 could yield a seed that an explicit request would have
        # had rejected).
        default_factory=lambda: np.random.randint(1, MAX_RAND),
        gt=0,
        lt=MAX_RAND,
    )
    strength: Optional[float] = 1.0
    init_image: Optional[str] = None


app = FastAPIApp()


@app.post("/generate")
def generate(args: GenerateArgs):
    """
    Generates an image from the Flux flow transformer.

    Args:
        args (GenerateArgs): Arguments for image generation:

            - `prompt`: The prompt used for image generation.
            - `width`: The width of the image.
            - `height`: The height of the image.
            - `num_steps`: The number of steps for the image generation.
            - `guidance`: The guidance for image generation, represents the
              influence of the prompt on the image generation.
            - `seed`: The seed for the image generation.
            - `strength`: strength for image generation, 0.0 - 1.0.
              Represents the percent of diffusion steps to run, setting the
              init_image as the noised latent at the given number of steps.
            - `init_image`: Base64 encoded image or path to image to use as
              the init image.

    Returns:
        StreamingResponse: The generated image as streaming jpeg bytes.
    """
    # Every request field is forwarded to the pipeline as a keyword argument.
    result = app.state.model.generate(**args.model_dump())
    return StreamingResponse(result, media_type="image/jpeg")


@app.post("/lora", response_model=LoraLoadResponse)
def lora_action(args: LoraArgs):
    """
    Loads or unloads a LoRA checkpoint into / from the Flux flow transformer.

    Args:
        args (LoraArgs): Arguments for the LoRA action:

            - `scale`: The scaling factor for the LoRA weights.
            - `path`: The path to the LoRA checkpoint. Required for "load".
            - `name`: The name of the LoRA checkpoint.
            - `action`: The action to perform, either "load" or "unload".

    Returns:
        JSONResponse: The status of the LoRA action. 200 on success, 400 when
        the action is invalid or a "load" request carries no path, and 500
        when the pipeline raises while performing the action.
    """
    # A load request without a checkpoint path is a client error; reject it
    # up front instead of handing ``None`` to the pipeline and reporting
    # whatever it raises as a server error.
    if args.action == "load" and not args.path:
        return JSONResponse(
            status_code=400,
            content={
                "status": "error",
                "message": "A 'path' is required to load a LoRA",
            },
        )

    try:
        if args.action == "load":
            app.state.model.load_lora(args.path, args.scale, args.name)
        elif args.action == "unload":
            # The name is preferred as the identifier; the path is the fallback.
            app.state.model.unload_lora(args.name if args.name else args.path)
        else:
            # Reachable when the client explicitly sends a null action.
            return JSONResponse(
                content={
                    "status": "error",
                    "message": f"Invalid action, expected 'load' or 'unload', got {args.action}",
                },
                status_code=400,
            )
    except Exception as e:
        # Endpoint boundary: surface any pipeline failure as a JSON error.
        return JSONResponse(
            status_code=500, content={"status": "error", "message": str(e)}
        )
    return JSONResponse(status_code=200, content={"status": "success"})