import io
import logging
import os
import uuid
from asyncio import Lock
from pathlib import Path

import numpy as np
import PIL.Image
import torch
import uvicorn
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse

from db import Database
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

MAX_SEED = np.iinfo(np.int32).max
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
SPACE_ID = os.environ.get("SPACE_ID", "")
DEV = os.environ.get("DEV", "0") == "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Persist the cache under /data when running inside a Space, otherwise use a local folder.
DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
IMGS_PATH = DB_PATH / "imgs"
DB_PATH.mkdir(exist_ok=True, parents=True)
IMGS_PATH.mkdir(exist_ok=True, parents=True)

database = Database(DB_PATH)
# Serialize generations so concurrent requests don't run the GPU pipelines at the same time.
generate_lock = Lock()
dtype = torch.bfloat16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    # Stable Cascade is a two-stage model: the prior turns the prompt into image
    # embeddings, and the decoder turns those embeddings into pixels.
    prior_pipeline = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
    ).to(device)
    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
    ).to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(
            prior_pipeline.prior, mode="reduce-overhead", fullgraph=True
        )
        decoder_pipeline.decoder = torch.compile(
            decoder_pipeline.decoder, mode="max-autotune", fullgraph=True
        )
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 20,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 10,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 1,
) -> PIL.Image.Image:
    generator = torch.Generator().manual_seed(seed)
    # Stage 1: the prior maps the prompt to image embeddings.
    prior_output = prior_pipeline(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=prior_num_inference_steps,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
    )
    # Stage 2: the decoder renders the embeddings into PIL images.
    decoder_output = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=decoder_num_inference_steps,
        # timesteps=decoder_timesteps,
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        generator=generator,
        output_type="pil",
    ).images
    return decoder_output[0]
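

# Minimal usage sketch (not part of the original app): calling generate() directly
# on a CUDA machine. The prompt, negative prompt, and seed below are illustrative.
#
#     image = generate(
#         "an astronaut riding a rainbow unicorn",
#         negative_prompt="blurry, low quality",
#         seed=42,
#     )
#     image.save("out.jpg")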
app = FastAPI()

origins = [
    "https://huggingface.co",
    "http://huggingface.co",
    "https://huggingface.co/",
    "http://huggingface.co/",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def validate_origin(request: Request, call_next):
    # Reject requests whose Referer is not an allowed origin, except in DEV mode.
    if DEV:
        return await call_next(request)
    if request.headers.get("referer") not in origins:
        raise HTTPException(status_code=403, detail="Forbidden")
    return await call_next(request)
@app.get("/image")  # route path assumed; adjust to match the original Space
async def generate_image(
    prompt: str, negative_prompt: str = "", seed: int = 2134213213
):
    # Serve a cached image for this (prompt, negative_prompt, seed) if one exists.
    cached_img = database.check(prompt, negative_prompt, seed)
    if cached_img:
        logging.info(f"Image found in cache: {cached_img[0]}")
        return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")

    logging.info("Image not found in cache, generating new image")
    async with generate_lock:
        # generate() is a blocking GPU call; the lock keeps concurrent requests
        # from running the pipelines simultaneously.
        pil_image = generate(prompt, negative_prompt, seed)
        img_id = str(uuid.uuid4())
        img_path = IMGS_PATH / f"{img_id}.jpg"
        pil_image.save(img_path)
        img_io = io.BytesIO()
        pil_image.save(img_io, "JPEG")
        img_io.seek(0)
        database.insert(prompt, negative_prompt, str(img_path), seed)
        return StreamingResponse(img_io, media_type="image/jpeg")
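

# Example request (hypothetical, assuming the "/image" route above and a local run):
#
#     curl "http://localhost:7860/image?prompt=a%20red%20fox&seed=42" -o fox.jpg
#
# In production the Referer check above only admits requests coming from huggingface.co.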
@app.get("/")
async def main():
    # redirect to https://huggingface.co/spaces/multimodalart/stable-cascade
    return RedirectResponse(
        "https://multimodalart-stable-cascade.hf.space/?__theme=system"
    )
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
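
# Running locally (sketch): DEV=1 skips the Referer check defined above, and the
# __main__ block starts uvicorn on port 7860.
#
#     DEV=1 python app.py
#
# The API is then reachable at http://0.0.0.0:7860.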