test1 / api.py
nikajoon's picture
Update api.py
a204c65 verified
raw
history blame
2.96 kB
import asyncio
import base64
import functools
import io
import threading
import warnings

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from transformers import pipeline

app = FastAPI()

# CORS settings: the browser origins allowed to call this API.
origins = [
    "http://localhost",
    "http://localhost:8000",
    "http://localhost:7860",
    "https://nikajoon-test1.hf.space",
    "http://tomko.ir",  # your site's domain
]

# Credentials are allowed, and any method/header is accepted from the origins above.
_cors_settings = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)

# Text-generation model used by extend_prompt() to enrich image prompts.
pipe = pipeline(task='text-generation', model='daspartho/prompt-extend')

def extend_prompt(prompt):
    """Expand ``prompt`` using the module-level text-generation ``pipe``.

    A trailing comma is appended to the prompt before generation. Exactly one
    sequence is requested, and its ``"generated_text"`` is returned.
    """
    seeded_text = prompt + ','
    outputs = pipe(seeded_text, num_return_sequences=1)
    first_output = outputs[0]
    return first_output["generated_text"]

def text_it(inputs):
    """Thin wrapper: return ``extend_prompt(inputs)`` unchanged."""
    expanded = extend_prompt(inputs)
    return expanded

def load_pipeline(use_cuda):
    """Return the SDXL-Turbo ``DiffusionPipeline`` placed on CUDA or CPU.

    CUDA is used only when ``use_cuda`` is truthy *and*
    ``torch.cuda.is_available()``. Otherwise the pipeline runs on CPU.

    The pipeline is cached per device. Only the first call for a given device
    loads the weights; later calls (``genie`` calls this once per request)
    reuse the same object instead of reloading the model every time.

    Args:
        use_cuda: Whether to prefer the GPU.

    Returns:
        The ``DiffusionPipeline`` on the selected device.
    """
    device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
    return _load_pipeline_for_device(device)


@functools.lru_cache(maxsize=None)  # at most two keys: "cuda" and "cpu"
def _load_pipeline_for_device(device):
    """Load SDXL-Turbo onto ``device`` ("cuda" or "cpu"); cached per device."""
    if device == "cuda":
        torch.cuda.empty_cache()
        # Half-precision weights on GPU to reduce memory use.
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except ImportError:
            # xformers is an optional package; without it, keep the default attention.
            warnings.warn(
                "xformers is not available; using default attention",
                RuntimeWarning,
            )
        pipe = pipe.to(device)
        torch.cuda.empty_cache()
    else:
        pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
        pipe = pipe.to(device)
    return pipe

def genie(prompt="sexy woman", use_details=False, steps=2, seed=398231747038484200, use_cuda=False):
    """Generate one image from ``prompt`` with the SDXL-Turbo pipeline.

    Args:
        prompt: Text prompt for the image.
        use_details: If true, first expand the prompt with ``extend_prompt``.
        steps: Number of denoising steps (``num_inference_steps``).
        seed: Seed for reproducible output. ``0`` means "no fixed seed": no
            generator is passed to the pipeline.
        use_cuda: Prefer the GPU when one is available.

    Returns:
        ``(image, prompt_used)``: the first generated image and the prompt
        actually sent to the pipeline (expanded or not).
    """
    pipe = load_pipeline(use_cuda)
    # The previous seed==0 branch evaluated np.random.seed(0), which returns
    # None, so the pipeline was already unseeded in that case; this is now explicit.
    # For other seeds, use a dedicated CPU generator. It produces the same stream
    # as the default CPU generator returned by torch.manual_seed(seed), without
    # reseeding torch's process-wide RNG on every request.
    generator = None if seed == 0 else torch.Generator().manual_seed(seed)
    extended_prompt = extend_prompt(prompt) if use_details else prompt
    # guidance_scale=0.0: no classifier-free guidance for this model.
    result = pipe(
        prompt=extended_prompt,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=0.0,
    )
    return result.images[0], extended_prompt

# Serializes image generation: one pipeline run (and one write of output.png)
# at a time, since generation now runs on worker threads, not the event loop.
_GENERATE_LOCK = threading.Lock()


def _generate_png_bytes(prompt):
    """Run ``genie(prompt)`` under the lock, save ``output.png`` and return PNG bytes."""
    with _GENERATE_LOCK:
        image, _ = genie(prompt)
        image.save("output.png")
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/generate")
async def generate_image(request: Request):
    """Generate an image for the JSON body ``{"prompt": "<text>"}``.

    The blocking model call runs in a worker thread so the event loop keeps
    serving other requests. The image is still saved to ``output.png``. It is
    also returned inline as a ``data:image/png;base64,...`` URL under
    ``"image_url"``, because no route serves ``output.png``.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or has a missing/empty/non-string ``"prompt"``.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError are ValueErrors
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from None
    # A non-object body (list, string, number) has no "prompt" key.
    prompt = data.get("prompt") if isinstance(data, dict) else None
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt must be a string")
    loop = asyncio.get_running_loop()
    png_bytes = await loop.run_in_executor(None, _generate_png_bytes, prompt)
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return JSONResponse(content={"image_url": f"data:image/png;base64,{encoded}"})

# Static placeholder page served at the site root.
_ROOT_PAGE_HTML = """
<html>
<head>
<title>My FastAPI App</title>
</head>
<body>
<h1>Hello from FastAPI!</h1>
<p>This is a sample application running on Hugging Face Spaces.</p>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Return the static HTML landing page."""
    return _ROOT_PAGE_HTML

if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)