# test1 / api.py
# Last commit: "Update api.py" (1762c84, verified) by nikajoon
import functools
import os
import threading
import uuid

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
from PIL import Image
from transformers import pipeline
app = FastAPI()

# CORS settings: origins whose browser pages may call this API
# (local development ports, this project's Hugging Face Space, and the site's own domain).
origins = [
    "http://localhost",
    "http://localhost:8000",
    "http://localhost:7860",
    "https://nikajoon-test1.hf.space",  # URL of this project's Hugging Face Space
    "http://tomko.ir",  # the website's own domain
]

app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)

# Text-generation model used by extend_prompt() to expand short prompts.
# NOTE: it is loaded at import time, so importing this module downloads/loads the model.
pipe = pipeline("text-generation", model="daspartho/prompt-extend")
def extend_prompt(prompt):
    """Return the text produced by the module-level prompt-extend model for *prompt*.

    A trailing comma is appended before generation (presumably to nudge the model
    into continuing with extra descriptive terms -- TODO confirm), and the
    ``generated_text`` of the single returned sequence is passed back unchanged.
    """
    seed_text = prompt + ","
    sequences = pipe(seed_text, num_return_sequences=1)
    first_sequence = sequences[0]
    return first_sequence["generated_text"]
def text_it(inputs):
    """Thin alias for extend_prompt(): return the extended version of *inputs*."""
    extended = extend_prompt(inputs)
    return extended
@functools.lru_cache(maxsize=2)
def _load_pipeline_for_device(device):
    """Load SDXL-Turbo onto *device* ("cuda" or "cpu"); cached so each device loads only once."""
    if device == "cuda":
        torch.cuda.empty_cache()
        # Half-precision weights on GPU to cut memory use.
        diffusion = DiffusionPipeline.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        try:
            diffusion.enable_xformers_memory_efficient_attention()
        except ImportError:
            # xformers is an optional speed/memory optimization; run without it
            # instead of failing when the package is not installed.
            pass
        diffusion = diffusion.to(device)
        torch.cuda.empty_cache()
    else:
        diffusion = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
        diffusion = diffusion.to(device)
    return diffusion


def load_pipeline(use_cuda):
    """Return the SDXL-Turbo diffusion pipeline, loading it at most once per device.

    Args:
        use_cuda: Request GPU inference. Falls back to CPU when CUDA is not available.

    Returns:
        A ``DiffusionPipeline`` placed on "cuda" or "cpu". The same pipeline object
        is returned on later calls that resolve to the same device, so it is shared;
        callers that may run concurrently should serialize their use of it.
    """
    device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
    # Resolve the device before hitting the cache so use_cuda=True on a CPU-only
    # host reuses the CPU pipeline instead of loading a second copy.
    return _load_pipeline_for_device(device)
def genie(prompt="sexy woman", use_details=False, steps=2, seed=398231747038484200, use_cuda=False):
    """Generate one image with the SDXL-Turbo pipeline.

    Args:
        prompt: Text prompt for the image.
        use_details: When True, expand the prompt with extend_prompt() first.
        steps: Number of inference steps passed to the pipeline.
        seed: Seed for reproducible output. 0 means "no fixed seed", so the
            pipeline falls back to its default (unseeded) randomness.
        use_cuda: Run on the GPU when one is available.

    Returns:
        ``(image, prompt_used)``: the first generated image and the prompt that was
        actually sent to the pipeline.
    """
    diffusion = load_pipeline(use_cuda)
    # A private generator keeps seeding local to this call instead of resetting the
    # process-wide torch RNG (torch.manual_seed) or reseeding numpy's global RNG.
    generator = None if seed == 0 else torch.Generator().manual_seed(seed)
    extended_prompt = extend_prompt(prompt) if use_details else prompt
    # Classifier-free guidance is disabled (guidance_scale=0.0).
    image = diffusion(
        prompt=extended_prompt,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=0.0,
    ).images[0]
    return image, extended_prompt
# Allows only one generation at a time. genie() runs in a worker thread so the event
# loop stays free; this lock keeps the one-at-a-time behaviour of a blocking call and
# avoids concurrent use of the diffusion pipeline.
_generation_lock = threading.Lock()


def _generate_and_save(prompt):
    """Generate an image for *prompt*, save it under images/ with a unique name, and return its URL path."""
    with _generation_lock:
        image, _ = genie(prompt)
    # A unique file name per request prevents concurrent requests from overwriting
    # each other's output before the client fetches it.
    # NOTE(review): generated files are never deleted; add cleanup if disk space matters.
    image_name = f"{uuid.uuid4().hex}.png"
    image_path = os.path.join("images", image_name)
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    image.save(image_path)
    return f"/images/{image_name}"


@app.post("/generate")
async def generate_image(request: Request):
    """Handle POST /generate.

    Expects a JSON object body with a non-empty string field ``"prompt"``.

    Returns:
        JSON ``{"image_url": "/images/<name>.png"}``, which the /images route serves.

    Raises:
        HTTPException: 400 when the body is not valid JSON or the prompt is missing or not a string.
    """
    try:
        data = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    # A non-object JSON body (list, string, number) has no "prompt" field.
    prompt = data.get("prompt") if isinstance(data, dict) else None
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt must be a string")
    # Generation is slow and blocking; run it in the threadpool so the event loop keeps serving other requests.
    image_url = await run_in_threadpool(_generate_and_save, prompt)
    return JSONResponse(content={"image_url": image_url})
@app.get("/images/{image_name}", response_class=FileResponse)
async def get_image(image_name: str):
    """Serve a generated image from the local ``images/`` directory.

    Only regular files located directly inside ``images/`` are served. A name that
    resolves anywhere else (for example ``..``) or to a directory is reported as
    not found.

    Raises:
        HTTPException: 404 when no such image file exists.
    """
    images_dir = os.path.realpath("images")
    image_path = os.path.realpath(os.path.join(images_dir, image_name))
    # os.path.exists() alone would accept "..", absolute paths and directories.
    # Require a regular file whose resolved parent is exactly the images folder.
    if os.path.dirname(image_path) != images_dir or not os.path.isfile(image_path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(image_path)
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve a minimal static HTML landing page for the root URL."""
    return """
<html>
<head>
<title>My FastAPI App</title>
</head>
<body>
<h1>Hello from FastAPI!</h1>
<p>This is a sample application running on Hugging Face Spaces.</p>
</body>
</html>
"""
if __name__ == "__main__":
    # Local entry point: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")