Spaces:
Running
Running
File size: 2,904 Bytes
4c519fd 5c239ba 9fbb486 47ab990 9fbb486 8f246ac 5c239ba 27660a3 5c239ba 47ab990 5c239ba 67f60f6 9fbb486 3750ff9 9fbb486 3750ff9 9fbb486 3750ff9 9fbb486 3750ff9 9fbb486 3750ff9 9fbb486 3750ff9 9fbb486 3750ff9 24eb369 c6fcf99 47ab990 c6fcf99 5c239ba 9fbb486 c6fcf99 9fbb486 4c519fd 5c239ba 4c519fd 5c239ba 9fbb486 5c239ba 9fbb486 5c239ba 9fbb486 5c239ba c6fcf99 5c239ba c6fcf99 9fbb486 5c239ba c6fcf99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from statistics import mean
from time import time

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from modules.details import rand_details
from modules.inference import generate_image
# ASGI application with the interactive API docs (Swagger UI and ReDoc)
# switched off; files under the local "static" directory are served at /static.
app = FastAPI(redoc_url=None, docs_url=None)
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# In-memory registry of every generation task, keyed by task id and shared by
# the request handlers and the background worker. Nothing in this file ever
# removes entries from it.
tasks = dict()
def get_place_in_queue(task_id):
    """Return the 1-based position of ``task_id`` among unfinished tasks.

    Every task whose status is "queued" or "processing" is ordered by its
    ``created_at`` timestamp (the task currently being processed is included
    in the count). Returns 0 when ``task_id`` is not among them, i.e. it is
    unknown, completed, or failed.
    """
    # Snapshot the values in one call: request handlers and the background
    # worker touch ``tasks`` from different threads, and iterating the live
    # view while another thread inserts a task can raise
    # "RuntimeError: dictionary changed size during iteration".
    pending = sorted(
        (task for task in list(tasks.values())
         if task["status"] in ("queued", "processing")),
        key=lambda task: task["created_at"],
    )
    pending_ids = [task["task_id"] for task in pending]
    try:
        return pending_ids.index(task_id) + 1
    except ValueError:
        # list.index raises ValueError for a missing id; catching only that
        # (instead of a bare ``except:``) no longer hides unrelated errors.
        return 0
def calculate_eta(task_id, default_duration=40):
    """Estimate the wait for ``task_id`` in seconds, rounded to 0.1 s.

    The estimate is the task's ``initial_place_in_queue`` multiplied by the
    mean ``completed_at - started_at`` over every task that has a
    ``completed_at`` timestamp (failed tasks included, since they are stamped
    too). Until any task has finished, ``default_duration`` seconds per queue
    position is assumed instead.

    Because it uses the place at submission time, the value does not shrink
    as the queue advances.
    NOTE(review): presumably the client offsets it by elapsed time; confirm.

    Args:
        task_id: key of an existing entry in ``tasks``.
        default_duration: per-task duration guess used while there is no
            history; defaults to 40 (the previously hard-coded value).

    Raises:
        KeyError: if ``task_id`` is unknown or lacks ``initial_place_in_queue``.
    """
    # Snapshot the values so concurrent inserts from request threads cannot
    # invalidate the iteration.
    durations = [
        task["completed_at"] - task["started_at"]
        for task in list(tasks.values())
        if "completed_at" in task
    ]
    per_task = mean(durations) if durations else default_duration
    return round(tasks[task_id]["initial_place_in_queue"] * per_task, 1)
def process_task(task_id):
    """Generate the image for ``task_id``, then keep draining the queue.

    If any task is already in the "processing" state this call does nothing;
    the worker handling that task picks up ``task_id`` when it drains the
    queue. Otherwise the task is marked "processing", ``started_at`` is
    stamped and ``generate_image`` is called with the task's prompt:

    * on success the result is stored under ``value`` and the status becomes
      "completed";
    * if ``generate_image`` raises an ``Exception``, its ``repr`` is stored
      under ``error`` and the status becomes "failed".

    ``completed_at`` is stamped in both cases. Afterwards the next task still
    in the "queued" state is processed in the same call.

    Draining is an explicit loop rather than a recursive call per task, so the
    call stack does not grow with the number of consecutively queued tasks
    (unbounded recursion would eventually raise ``RecursionError``).
    """
    while True:
        # list(...) snapshots the dict in one call; iterating the live view
        # while a request thread inserts a task could raise
        # "RuntimeError: dictionary changed size during iteration".
        # NOTE(review): this check-then-mark is not atomic. If background
        # tasks run concurrently in a thread pool, two workers could both
        # pass it. Consider a lock if that matters.
        if any(task["status"] == "processing" for task in list(tasks.values())):
            return

        current = tasks[task_id]
        current["status"] = "processing"
        current["started_at"] = time()
        try:
            current["value"] = generate_image(current["prompt"])
        except Exception as ex:
            current["status"] = "failed"
            current["error"] = repr(ex)
        else:
            current["status"] = "completed"
        finally:
            current["completed_at"] = time()

        queued_tasks = [
            task for task in list(tasks.values()) if task["status"] == "queued"
        ]
        if not queued_tasks:
            return
        print(f"Tasks remaining: {len(queued_tasks)}")
        # Dict insertion order == creation order, so this is the oldest
        # still-queued task.
        task_id = queued_tasks[0]["task_id"]
@app.head("/")
@app.get("/")
def index():
    """Serve the front-end page ``static/index.html`` for GET and HEAD on "/"."""
    page_path = "static/index.html"
    return FileResponse(page_path, media_type="text/html")
@app.get("/details")
def generate_details():
    """Return the result of ``rand_details()`` unchanged.

    The payload's shape is defined in ``modules.details``, which is not
    visible here.
    """
    details = rand_details()
    return details
@app.get("/task/create")
def create_task(background_tasks: BackgroundTasks, prompt: str = "покемон"):
    """Register a new generation task for ``prompt`` and schedule it.

    The default prompt "покемон" is Russian for "pokemon". The task id is the
    creation timestamp, an underscore, and the prompt. The new task starts in
    the "queued" state with ``poll_count`` 0. It records its initial place in
    the queue, and ``process_task`` is scheduled for it as a background task.
    The stored task dict itself is returned as the response body.
    """
    now = time()
    new_id = f"{now}_{prompt}"
    entry = dict(
        task_id=new_id,
        created_at=now,
        prompt=prompt,
        status="queued",
        poll_count=0,
    )
    # Insert first: get_place_in_queue() finds the task by scanning ``tasks``.
    tasks[new_id] = entry
    entry["initial_place_in_queue"] = get_place_in_queue(new_id)
    background_tasks.add_task(process_task, new_id)
    # NOTE(review): nothing in this file evicts entries (including generated
    # results) from ``tasks``, so memory grows with every request.
    return entry
@app.get('/task/poll')
def poll_task(task_id: str):
    """Return the current state of ``task_id`` after refreshing its queue info.

    Updates ``place_in_queue`` (0 once the task is no longer queued or
    processing) and ``eta`` (from ``calculate_eta``). Increments
    ``poll_count``, then returns the stored task dict.

    Raises:
        HTTPException: 404 when ``task_id`` is unknown, instead of letting a
            ``KeyError`` escape and reach the client as a 500 error.
    """
    task = tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    task["place_in_queue"] = get_place_in_queue(task_id)
    task["eta"] = calculate_eta(task_id)
    task["poll_count"] += 1
    return task
|