Spaces:
Paused
Paused
File size: 8,616 Bytes
d69879c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 |
"""
FacePoke API
Author: Julian Bilcke
Date: September 30, 2024
"""
import sys
import asyncio
import hashlib
from aiohttp import web, WSMsgType
import json
import uuid
import logging
import os
import zipfile
import signal
from typing import Dict, Any, List, Optional
import base64
import io
from PIL import Image
import numpy as np
# Logging setup: emit everything from DEBUG upward, each record prefixed with
# its timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
# Raise asyncio's own logger to DEBUG as well so its diagnostics are not filtered out.
logging.getLogger("asyncio").setLevel(logging.DEBUG)
logger.debug(f"Python version: {sys.version}")
# SIGSEGV reporting
def SIGSEGV_signal_arises(signalNum, stack):
    """
    Log a critical message for a received signal, including the Python stack.

    This function is kept for backward compatibility but is no longer
    installed as the SIGSEGV handler; crash reporting is done by
    ``faulthandler`` below.

    Args:
        signalNum: The signal number that was received.
        stack: The interrupted stack frame, or None.
    """
    import traceback

    logger.critical(f"{signalNum} : SIGSEGV arises")
    if stack is not None:
        # Render the frame chain as a readable traceback instead of a frame repr.
        logger.critical("Stack trace:\n" + "".join(traceback.format_stack(stack)))
    else:
        logger.critical(f"Stack trace: {stack}")


# A Python-level handler for SIGSEGV cannot work: CPython's C-level handler
# only records the signal and returns to the faulting C code, which faults
# again, so the process appears to hang and the Python handler never runs.
# faulthandler dumps the Python tracebacks of all threads to stderr directly
# from the low-level handler instead.
import faulthandler

try:
    faulthandler.enable()
except Exception as exc:  # e.g. stderr has no usable file descriptor
    # Missing crash diagnostics should not prevent the server from starting.
    logger.warning(f"Could not enable faulthandler: {exc}")
from loader import initialize_models
from engine import Engine, base64_data_uri_to_PIL_Image, create_engine
# Global constants
# Root directory for runtime data; overridable through the DATA_ROOT env var.
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")
# In-memory cache of PIL images keyed by a string (presumably an image hash —
# TODO confirm against the engine code).
image_cache: Dict[str, Image.Image] = {}
# Registry of active sessions keyed by session id. The shutdown and periodic
# cleanup hooks look this name up at module level, so it must exist even
# when no session has been registered.
active_sessions: Dict[str, Any] = {}
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """
    Handle WebSocket connections for the FacePoke application.

    Every TEXT frame is decoded as a JSON object. Objects whose ``type`` is
    ``"modify_image"`` are passed to ``handle_modify_image`` together with the
    client-supplied ``uuid``. Frames of other types, and JSON objects with any
    other ``type``, are ignored. The loop ends when the peer closes the
    connection, the socket is closing/closed, or an ERROR frame arrives.

    Args:
        request (web.Request): The incoming request object.

    Returns:
        web.WebSocketResponse: The WebSocket response object.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    # Message types after which no further frames will arrive. receive() keeps
    # returning CLOSING/CLOSED once the socket is shutting down, so they must
    # end the loop too, otherwise it spins on a dead socket.
    terminal_types = (
        WSMsgType.CLOSE,
        WSMsgType.CLOSING,
        WSMsgType.CLOSED,
        WSMsgType.ERROR,
    )
    try:
        logger.info("New WebSocket connection established")
        while True:
            msg = await ws.receive()
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                except json.JSONDecodeError as e:
                    # A single malformed frame should not tear down the connection.
                    logger.warning(f"Ignoring malformed JSON message: {e}")
                    continue
                if not isinstance(data, dict):
                    logger.warning("Ignoring JSON message that is not an object")
                    continue
                # let's not log user requests, they are heavy
                if data.get('type') == 'modify_image':
                    # Named request_uuid so it does not shadow the `uuid` module.
                    request_uuid = data.get('uuid')
                    if not request_uuid:
                        logger.warning("Received message without UUID")
                    await handle_modify_image(request, ws, data, request_uuid)
            elif msg.type in terminal_types:
                logger.warning(f"WebSocket connection closed: {msg.type}")
                break
    except Exception as e:
        logger.error(f"Error in websocket_handler: {str(e)}")
        logger.exception("Full traceback:")
    finally:
        logger.info("WebSocket connection closed")
    return ws
async def handle_modify_image(request: web.Request, ws: web.WebSocketResponse, msg: Dict[str, Any], uuid: str):
    """
    Run an image modification through the engine and reply on the WebSocket.

    The source handed to the engine is ``msg['image_hash']`` when that value is
    truthy, otherwise ``msg['image']``. The reply is a ``"modified_image"``
    message. On success it carries the modified image and
    ``engine.get_image_hash(source)``. On any exception it carries
    ``success: False`` and the error text. Both replies echo ``uuid``.

    Args:
        request (web.Request): The incoming request; its app holds the engine.
        ws (web.WebSocketResponse): Socket the reply is sent on.
        msg (Dict[str, Any]): Request payload with ``image_hash`` or ``image``
            plus ``params``.
        uuid: Client-supplied request identifier, echoed back.
    """
    logger.info("Received modify_image request")
    try:
        engine = request.app['engine']
        # Prefer a previously-seen image referenced by hash; fall back to raw data.
        source = msg.get('image_hash') or msg['image']
        result_b64 = await engine.modify_image(source, msg['params'])
        reply = {
            "type": "modified_image",
            "image": result_b64,
            "image_hash": engine.get_image_hash(source),
            "success": True,
            "uuid": uuid,  # lets the client match the reply to its request
        }
        await ws.send_json(reply)
        logger.info("Successfully sent modified image")
    except Exception as e:
        logger.error(f"Error in modify_image: {str(e)}")
        failure = {
            "type": "modified_image",
            "success": False,
            "error": str(e),
            "uuid": uuid,  # echoed on failure too
        }
        await ws.send_json(failure)
async def index(request: web.Request) -> web.Response:
    """
    Serve the index.html file.

    The file is read from the ``public`` directory next to this module on each
    request. It is opened in a ``with`` block so the handle is closed promptly,
    and decoded explicitly as UTF-8 rather than with the locale default.
    """
    path = os.path.join(os.path.dirname(__file__), "public", "index.html")
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    return web.Response(content_type="text/html", text=content)
async def js_index(request: web.Request) -> web.Response:
    """
    Serve the index.js file.

    The file is read from the ``public`` directory next to this module on each
    request. It is opened in a ``with`` block so the handle is closed promptly,
    and decoded explicitly as UTF-8 rather than with the locale default.
    """
    path = os.path.join(os.path.dirname(__file__), "public", "index.js")
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    return web.Response(content_type="application/javascript", text=content)
async def hf_logo(request: web.Request) -> web.Response:
    """
    Serve the hf-logo.svg file.

    The file is read from the ``public`` directory next to this module on each
    request. It is opened in a ``with`` block so the handle is closed promptly,
    and decoded explicitly as UTF-8 rather than with the locale default.
    """
    path = os.path.join(os.path.dirname(__file__), "public", "hf-logo.svg")
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    return web.Response(content_type="image/svg+xml", text=content)
async def on_shutdown(app: web.Application):
    """
    Cleanup function to be called on server shutdown.

    Stops every session in the module-level ``active_sessions`` registry, if
    that registry exists, and empties it. Then awaits ``cleanup()`` on the
    engine stored under ``app['engine']``, if present. A session whose
    ``stop()`` raises is logged and skipped, so engine cleanup still runs.

    Args:
        app (web.Application): The application being shut down.
    """
    logger.info("Server shutdown initiated, cleaning up resources...")
    # Look the registry up defensively: a bare reference raises NameError when
    # the module does not define it, which would skip engine cleanup entirely.
    sessions = globals().get('active_sessions')
    if sessions:
        for session in list(sessions.values()):
            try:
                await session.stop()
            except Exception:
                logger.exception("Failed to stop a session during shutdown")
        sessions.clear()
    logger.info("All active sessions have been closed")
    if 'engine' in app:
        await app['engine'].cleanup()
        logger.info("Engine instance cleaned up")
    logger.info("Server shutdown complete")
async def initialize_app() -> web.Application:
    """
    Initialize and configure the web application.

    Loads the models, creates the Engine and stores it under
    ``app['engine']``. Registers the ``on_shutdown`` hook and wires up the
    HTTP routes (``/``, ``/index.js``, ``/hf-logo.svg``) and the WebSocket
    route (``/ws``).

    Returns:
        web.Application: The configured application.

    Raises:
        Exception: Any error raised while loading models, creating the engine
            or configuring the app is logged with its traceback and re-raised.
    """
    try:
        logger.info("Initializing application...")
        models = await initialize_models()

        logger.info("🚀 Creating Engine instance...")
        engine = create_engine(models)
        logger.info("✅ Engine instance created.")

        app = web.Application()
        app['engine'] = engine
        app.on_shutdown.append(on_shutdown)

        # Configure routes
        app.router.add_get("/", index)
        app.router.add_get("/index.js", js_index)
        app.router.add_get("/hf-logo.svg", hf_logo)
        app.router.add_get("/ws", websocket_handler)
        logger.info("Application routes configured")
        return app
    except Exception as e:
        logger.error(f"🚨 Error during application initialization: {str(e)}")
        logger.exception("Full traceback:")
        raise
async def start_background_tasks(app: web.Application):
    """
    Launch the application's long-running background jobs.

    Currently the only job is ``periodic_cleanup``. Its task handle is kept in
    ``app['cleanup_task']`` so the cleanup hook can cancel it later.

    Args:
        app (web.Application): The web application instance.
    """
    cleanup_coro = periodic_cleanup(app)
    app['cleanup_task'] = asyncio.create_task(cleanup_coro)
async def cleanup_background_tasks(app: web.Application):
    """
    Clean up background tasks when the application is shutting down.

    Cancels the task stored in ``app['cleanup_task']`` and waits for it to
    finish. Does nothing if no such task was stored.

    Args:
        app (web.Application): The web application instance.
    """
    task = app.get('cleanup_task')
    if task is None:
        # Startup never stored the task (e.g. it failed early); nothing to stop.
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Raised when the task is cancelled before it can handle the
        # cancellation itself (e.g. before its first step ran). That is
        # the expected outcome here, not an error.
        pass
async def periodic_cleanup(app: web.Application):
    """
    Run the inactive-session cleanup once per hour until the task is cancelled.

    Cancellation ends the loop and returns normally. Any other exception from
    a cleanup pass is logged with its traceback, and the loop carries on.

    Args:
        app (web.Application): The web application instance.
    """
    one_hour = 3600  # seconds between cleanup passes
    while True:
        try:
            await asyncio.sleep(one_hour)
            await cleanup_inactive_sessions(app)
        except asyncio.CancelledError:
            return
        except Exception as err:
            logger.error(f"Error in periodic cleanup: {str(err)}")
            logger.exception("Full traceback:")
async def cleanup_inactive_sessions(app: web.Application):
    """
    Clean up inactive sessions.

    A session counts as inactive when its ``is_running`` event is not set.
    Sessions are read from the module-level ``active_sessions`` registry, if
    it exists. Each inactive one is removed from the registry and its
    ``stop()`` is awaited. A failing ``stop()`` is logged and does not abort
    the rest of the pass.

    Args:
        app (web.Application): The web application instance (not used here).
    """
    logger.info("Starting cleanup of inactive sessions")
    # Look the registry up defensively: a bare reference raises NameError when
    # the module does not define it, making every hourly pass fail.
    sessions = globals().get('active_sessions') or {}
    inactive_sessions = [
        session_id for session_id, session in sessions.items()
        if not session.is_running.is_set()
    ]
    for session_id in inactive_sessions:
        # pop with a default: the await below yields, and another coroutine may
        # have removed the entry in the meantime.
        session = sessions.pop(session_id, None)
        if session is None:
            continue
        try:
            await session.stop()
        except Exception:
            logger.exception(f"Failed to stop inactive session: {session_id}")
            continue
        logger.info(f"Cleaned up inactive session: {session_id}")
    logger.info(f"Cleaned up {len(inactive_sessions)} inactive sessions")
def main():
    """
    Main function to start the FacePoke application.

    Builds the application (models, engine, routes) inside the event loop that
    aiohttp uses to serve it. Registers the background-task startup/cleanup
    hooks and runs the web server on 0.0.0.0:8080. If startup fails, the error
    is logged with its traceback and the process exits with status 1, so
    supervisors see the failure.
    """
    async def _build_app() -> web.Application:
        # Built inside run_app's loop. Initializing under asyncio.run() used a
        # separate loop that is closed before serving starts, so loop-bound
        # objects created during model/engine setup could end up tied to a
        # dead loop.
        app = await initialize_app()
        app.on_startup.append(start_background_tasks)
        app.on_cleanup.append(cleanup_background_tasks)
        logger.info("Application initialized, starting web server")
        return app

    try:
        logger.info("Starting FacePoke application")
        web.run_app(_build_app(), host="0.0.0.0", port=8080)
    except Exception as e:
        logger.critical(f"🚨 FATAL: Failed to start the app: {str(e)}")
        logger.exception("Full traceback:")
        sys.exit(1)


if __name__ == "__main__":
    main()
|