nikajoon commited on
Commit
a204c65
1 Parent(s): 0a4ccdb

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +57 -3
api.py CHANGED
@@ -1,6 +1,12 @@
1
- from fastapi import FastAPI
2
- from fastapi.responses import HTMLResponse
3
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
4
 
5
  app = FastAPI()
6
 
@@ -10,7 +16,7 @@ origins = [
10
  "http://localhost:8000",
11
  "http://localhost:7860",
12
  "https://nikajoon-test1.hf.space",
13
- # "http://yourdomain.com", # دامنه سایت شما
14
  ]
15
 
16
  app.add_middleware(
@@ -21,6 +27,50 @@ app.add_middleware(
21
  allow_headers=["*"],
22
  )
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  @app.get("/", response_class=HTMLResponse)
25
  async def read_root():
26
  html_content = """
@@ -35,3 +85,7 @@ async def read_root():
35
  </html>
36
  """
37
  return html_content
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ import threading
5
+ import gradio as gr
6
+ import torch
7
+ import numpy as np
8
+ from diffusers import DiffusionPipeline
9
+ from transformers import pipeline
10
 
11
  app = FastAPI()
12
 
 
16
  "http://localhost:8000",
17
  "http://localhost:7860",
18
  "https://nikajoon-test1.hf.space",
19
+ "http://tomko.ir", # دامنه سایت شما
20
  ]
21
 
22
  app.add_middleware(
 
27
  allow_headers=["*"],
28
  )
29
 
30
+ pipe = pipeline('text-generation', model='daspartho/prompt-extend')
31
+
32
+ def extend_prompt(prompt):
33
+ return pipe(prompt+',', num_return_sequences=1)[0]["generated_text"]
34
+
35
+ def text_it(inputs):
36
+ return extend_prompt(inputs)
37
+
38
+ def load_pipeline(use_cuda):
39
+ device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
40
+ if device == "cuda":
41
+ torch.cuda.max_memory_allocated(device=device)
42
+ torch.cuda.empty_cache()
43
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
44
+ pipe.enable_xformers_memory_efficient_attention()
45
+ pipe = pipe.to(device)
46
+ torch.cuda.empty_cache()
47
+ else:
48
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
49
+ pipe = pipe.to(device)
50
+ return pipe
51
+
52
+ def genie(prompt="sexy woman", use_details=False, steps=2, seed=398231747038484200, use_cuda=False):
53
+ pipe = load_pipeline(use_cuda)
54
+ generator = np.random.seed(0) if seed == 0 else torch.manual_seed(seed)
55
+ if use_details:
56
+ extended_prompt = extend_prompt(prompt)
57
+ else:
58
+ extended_prompt = prompt
59
+ int_image = pipe(prompt=extended_prompt, generator=generator, num_inference_steps=steps, guidance_scale=0.0).images[0]
60
+ return int_image, extended_prompt
61
+
62
+ @app.post("/generate")
63
+ async def generate_image(request: Request):
64
+ data = await request.json()
65
+ prompt = data.get("prompt")
66
+ if not prompt:
67
+ raise HTTPException(status_code=400, detail="Prompt is required")
68
+
69
+ image, _ = genie(prompt)
70
+ # ذخیره‌سازی تصویر و بازگرداندن URL یا داده‌های تصویر به عنوان پاسخ
71
+ image.save("output.png")
72
+ return JSONResponse(content={"image_url": "output.png"})
73
+
74
  @app.get("/", response_class=HTMLResponse)
75
  async def read_root():
76
  html_content = """
 
85
  </html>
86
  """
87
  return html_content
88
+
89
+ if __name__ == "__main__":
90
+ import uvicorn
91
+ uvicorn.run(app, host="0.0.0.0", port=7860)