Spaces:
Paused
Paused
"""
FacePoke API
Author: Julian Bilcke
Date: September 30, 2024
"""
import sys | |
import asyncio | |
import hashlib | |
from aiohttp import web, WSMsgType | |
import json | |
import uuid | |
import logging | |
import os | |
import zipfile | |
import signal | |
from typing import Dict, Any, List, Optional | |
import base64 | |
import io | |
from PIL import Image | |
import numpy as np | |
# Configure logging: the root logger emits DEBUG and above, and every record
# carries a timestamp, the logger name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Set asyncio logger to DEBUG level so its own diagnostics are emitted too.
asyncio_logger = logging.getLogger("asyncio")
asyncio_logger.setLevel(logging.DEBUG)

# Record the interpreter version once at import time.
logger.debug(f"Python version: {sys.version}")
# SIGSEGV handler
def SIGSEGV_signal_arises(signalNum, stack):
    """
    Log two CRITICAL records when the registered signal is delivered.

    Args:
        signalNum: The signal number passed in by the ``signal`` module.
        stack: The second handler argument supplied by the ``signal`` module;
            it is interpolated into the log message as-is.
    """
    report_lines = (
        f"{signalNum} : SIGSEGV arises",
        f"Stack trace: {stack}",
    )
    for report_line in report_lines:
        logger.critical(report_line)


# Install the handler for SIGSEGV at import time.
signal.signal(signal.SIGSEGV, SIGSEGV_signal_arises)
from loader import initialize_models | |
from engine import Engine, base64_data_uri_to_PIL_Image, create_engine | |
# Global constants
# Root directory for runtime data; overridable through the DATA_ROOT
# environment variable.
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
# Directory under DATA_ROOT where model files are expected.
MODELS_DIR = os.path.join(DATA_ROOT, "models")

# Module-level image cache keyed by string.
image_cache: Dict[str, Image.Image] = {}

# Registry of live sessions keyed by session id. The shutdown hook and the
# periodic cleanup both read this name, so it must exist at module level even
# when no session has been registered; without it they raise NameError.
active_sessions: Dict[str, Any] = {}
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """
    Handle WebSocket connections for the FacePoke application.

    Reads frames until the connection ends. Text frames are parsed as JSON;
    messages whose ``type`` is ``'modify_image'`` are forwarded to
    ``handle_modify_image``. Any other frame or message type is ignored.

    Args:
        request (web.Request): The incoming request object.

    Returns:
        web.WebSocketResponse: The WebSocket response object.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    # Every message type that means the connection is finished. CLOSING and
    # CLOSED must be included: after an abrupt disconnect receive() keeps
    # returning CLOSED, and looping on it ends in a RuntimeError.
    terminal_types = (
        WSMsgType.CLOSE,
        WSMsgType.CLOSING,
        WSMsgType.CLOSED,
        WSMsgType.ERROR,
    )

    try:
        logger.info("New WebSocket connection established")
        while True:
            msg = await ws.receive()

            if msg.type in terminal_types:
                logger.warning(f"WebSocket connection closed: {msg.type}")
                break

            if msg.type != WSMsgType.TEXT:
                # Binary / ping / pong frames carry nothing for this handler.
                continue

            # A single malformed frame should not tear down the connection.
            try:
                data = json.loads(msg.data)
            except json.JSONDecodeError:
                logger.warning("Received a text frame that is not valid JSON")
                continue
            if not isinstance(data, dict):
                logger.warning("Received a JSON message that is not an object")
                continue

            # let's not log user requests, they are heavy
            if data.get('type') == 'modify_image':
                # Named request_uuid so the imported uuid module is not shadowed.
                request_uuid = data.get('uuid')
                if not request_uuid:
                    logger.warning("Received message without UUID")
                await handle_modify_image(request, ws, data, request_uuid)
    except Exception as e:
        logger.error(f"Error in websocket_handler: {str(e)}")
        logger.exception("Full traceback:")
    finally:
        # This handler creates no per-connection session object, so there is
        # nothing to tear down here beyond recording the close.
        logger.info("WebSocket connection closed")
    return ws
async def handle_modify_image(request: web.Request, ws: web.WebSocketResponse, msg: Dict[str, Any], uuid: str):
    """
    Process one 'modify_image' message and reply on the same WebSocket.

    The engine stored under ``request.app['engine']`` is given either the
    message's ``image_hash`` (when present and truthy) or its ``image`` value,
    together with ``msg['params']``. The reply is always a ``modified_image``
    JSON message echoing ``uuid``; on failure it carries ``success: False``
    and the error text instead of the image.

    Args:
        request (web.Request): The incoming request object.
        ws (web.WebSocketResponse): The WebSocket response object.
        msg (Dict[str, Any]): The message containing the image or image_hash and modification parameters.
        uuid: A unique identifier for the request, echoed back unchanged.
    """
    logger.info("Received modify_image request")
    try:
        engine = request.app['engine']

        # Prefer the hash when the client sent one; otherwise the 'image'
        # entry is required (a missing key is reported via the error reply).
        source = msg.get('image_hash') or msg['image']

        result_base64 = await engine.modify_image(source, msg['params'])
        result_hash = engine.get_image_hash(source)

        success_reply = {
            "type": "modified_image",
            "image": result_base64,
            "image_hash": result_hash,
            "success": True,
            "uuid": uuid,  # Include the UUID in the response
        }
        await ws.send_json(success_reply)
        logger.info("Successfully sent modified image")
    except Exception as exc:
        logger.error(f"Error in modify_image: {str(exc)}")
        failure_reply = {
            "type": "modified_image",
            "success": False,
            "error": str(exc),
            "uuid": uuid,  # Include the UUID even in error responses
        }
        await ws.send_json(failure_reply)
async def index(request: web.Request) -> web.Response:
    """
    Serve the index.html file.

    Args:
        request (web.Request): The incoming request object (unused).

    Returns:
        web.Response: An HTML response whose body is the content of
        ``public/index.html`` located next to this module.
    """
    page_path = os.path.join(os.path.dirname(__file__), "public", "index.html")
    # The context manager closes the file deterministically, and explicit
    # UTF-8 keeps decoding independent of the host locale.
    with open(page_path, "r", encoding="utf-8") as page_file:
        content = page_file.read()
    return web.Response(content_type="text/html", text=content)
async def js_index(request: web.Request) -> web.Response:
    """
    Serve the index.js file.

    Args:
        request (web.Request): The incoming request object (unused).

    Returns:
        web.Response: A JavaScript response whose body is the content of
        ``public/index.js`` located next to this module.
    """
    script_path = os.path.join(os.path.dirname(__file__), "public", "index.js")
    # The context manager closes the file deterministically, and explicit
    # UTF-8 keeps decoding independent of the host locale.
    with open(script_path, "r", encoding="utf-8") as script_file:
        content = script_file.read()
    return web.Response(content_type="application/javascript", text=content)
async def hf_logo(request: web.Request) -> web.Response:
    """
    Serve the hf-logo.svg file.

    Args:
        request (web.Request): The incoming request object (unused).

    Returns:
        web.Response: An SVG response whose body is the content of
        ``public/hf-logo.svg`` located next to this module.
    """
    logo_path = os.path.join(os.path.dirname(__file__), "public", "hf-logo.svg")
    # The context manager closes the file deterministically, and explicit
    # UTF-8 keeps decoding independent of the host locale.
    with open(logo_path, "r", encoding="utf-8") as logo_file:
        content = logo_file.read()
    return web.Response(content_type="image/svg+xml", text=content)
async def on_shutdown(app: web.Application):
    """
    Cleanup function to be called on server shutdown.

    Stops every session currently held in the module-level
    ``active_sessions`` mapping, empties that mapping, and then awaits the
    engine's ``cleanup()`` if an engine is stored on the application.

    NOTE(review): ``active_sessions`` is read as a module global; confirm it
    is defined before this hook can run.

    Args:
        app (web.Application): The web application instance.
    """
    logger.info("Server shutdown initiated, cleaning up resources...")

    # Work on a snapshot so the registry can change while sessions stop.
    sessions_to_stop = list(active_sessions.values())
    for open_session in sessions_to_stop:
        await open_session.stop()
    active_sessions.clear()
    logger.info("All active sessions have been closed")

    if 'engine' in app:
        engine = app['engine']
        await engine.cleanup()
        logger.info("Engine instance cleaned up")

    logger.info("Server shutdown complete")
async def initialize_app() -> web.Application:
    """
    Initialize and configure the web application.

    Loads the models, builds the engine from them, stores the engine under
    ``app['engine']``, registers the shutdown hook and wires the HTTP and
    WebSocket routes.

    Returns:
        web.Application: The configured application.

    Raises:
        Exception: Any failure during initialization is logged and re-raised.
    """
    try:
        logger.info("Initializing application...")
        loaded_models = await initialize_models()

        logger.info("π Creating Engine instance...")
        app_engine = create_engine(loaded_models)
        logger.info("β Engine instance created.")

        application = web.Application()
        application['engine'] = app_engine
        application.on_shutdown.append(on_shutdown)

        # Configure routes (registered in this exact order).
        route_table = (
            ("/", index),
            ("/index.js", js_index),
            ("/hf-logo.svg", hf_logo),
            ("/ws", websocket_handler),
        )
        for route_path, route_handler in route_table:
            application.router.add_get(route_path, route_handler)
        logger.info("Application routes configured")

        return application
    except Exception as exc:
        logger.error(f"π¨ Error during application initialization: {str(exc)}")
        logger.exception("Full traceback:")
        raise
async def start_background_tasks(app: web.Application):
    """
    Launch the application's background work.

    Schedules ``periodic_cleanup(app)`` as an asyncio task and keeps the task
    under ``app['cleanup_task']`` so it can be cancelled later.

    Args:
        app (web.Application): The web application instance.
    """
    cleanup_coro = periodic_cleanup(app)
    app['cleanup_task'] = asyncio.create_task(cleanup_coro)
async def cleanup_background_tasks(app: web.Application):
    """
    Clean up background tasks when the application is shutting down.

    Cancels the task stored under ``app['cleanup_task']`` and waits for it to
    finish. Does nothing when no such task was ever stored.

    Args:
        app (web.Application): The web application instance.
    """
    task = app.get('cleanup_task')
    if task is None:
        # The startup hook never ran, so there is nothing to cancel.
        return

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected result of the cancel() above; it must not escape the
        # cleanup hook and abort the rest of application teardown.
        pass
async def periodic_cleanup(app: web.Application):
    """
    Run the inactive-session cleanup forever, once per hour.

    Each cycle sleeps first and then awaits ``cleanup_inactive_sessions``.
    Cancellation ends the loop quietly; any other exception is logged and the
    loop carries on with the next cycle.

    Args:
        app (web.Application): The web application instance.
    """
    seconds_between_runs = 3600  # Run cleanup every hour
    while True:
        try:
            await asyncio.sleep(seconds_between_runs)
            await cleanup_inactive_sessions(app)
        except asyncio.CancelledError:
            # Cancelled by the shutdown hook: stop without re-raising.
            return
        except Exception as exc:
            logger.error(f"Error in periodic cleanup: {str(exc)}")
            logger.exception("Full traceback:")
async def cleanup_inactive_sessions(app: web.Application):
    """
    Remove and stop every session whose ``is_running`` flag is not set.

    Sessions are looked up in the module-level ``active_sessions`` mapping.
    The ids to remove are collected first so the mapping is not mutated while
    it is being iterated.

    NOTE(review): ``active_sessions`` is read as a module global; confirm it
    is defined before this coroutine can run.

    Args:
        app (web.Application): The web application instance (unused here).
    """
    logger.info("Starting cleanup of inactive sessions")

    stale_ids = []
    for candidate_id, candidate in active_sessions.items():
        if not candidate.is_running.is_set():
            stale_ids.append(candidate_id)

    for stale_id in stale_ids:
        stale_session = active_sessions.pop(stale_id)
        await stale_session.stop()
        logger.info(f"Cleaned up inactive session: {stale_id}")

    logger.info(f"Cleaned up {len(stale_ids)} inactive sessions")
def main():
    """
    Main function to start the FacePoke application.

    Builds the application and serves it on 0.0.0.0:8080. Any exception
    raised while building or running the app is logged at CRITICAL level
    together with its traceback; it is not re-raised.
    """
    async def _build_app() -> web.Application:
        """Create the app and attach the background-task hooks."""
        app = await initialize_app()
        app.on_startup.append(start_background_tasks)
        app.on_cleanup.append(cleanup_background_tasks)
        logger.info("Application initialized, starting web server")
        return app

    try:
        logger.info("Starting FacePoke application")
        # Hand run_app the factory coroutine rather than a pre-built app, so
        # initialization runs on the same event loop that serves requests.
        # Building it with asyncio.run() would use a loop that is already
        # closed by the time the server starts.
        web.run_app(_build_app(), host="0.0.0.0", port=8080)
    except Exception as e:
        logger.critical(f"π¨ FATAL: Failed to start the app: {str(e)}")
        logger.exception("Full traceback:")


if __name__ == "__main__":
    main()