File size: 8,616 Bytes
d69879c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""
FacePoke API

Author: Julian Bilcke
Date: September 30, 2024
"""

import asyncio
import base64
import contextlib
import faulthandler
import hashlib
import io
import json
import logging
import os
import signal
import sys
import traceback
import uuid
import zipfile
from typing import Dict, Any, List, Optional

import numpy as np
from aiohttp import web, WSMsgType
from PIL import Image

# Configure the root logger: DEBUG level, records formatted as
# "<time> - <logger name> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Let DEBUG records emitted by asyncio's own logger through as well.
logging.getLogger("asyncio").setLevel(logging.DEBUG)

logger.debug(f"Python version: {sys.version}")

# Crash diagnostics.
#
# A Python-level SIGSEGV handler cannot work: Python's C-level trampoline
# only records the signal and returns to the faulting C code, which faults
# again, so the process appears to hang instead of logging anything.
# `faulthandler` installs a real C-level handler that writes the Python
# traceback of every thread to stderr and then lets the process die. It is
# enabled here, before the model/engine imports below, so crashes inside
# native code loaded by those imports are reported too.
def SIGSEGV_signal_arises(signalNum, stack):
    """Log a fatal-signal notification together with a formatted Python stack.

    Kept for backward compatibility; it is no longer registered as the
    SIGSEGV handler (faulthandler.enable() below is used instead).

    Args:
        signalNum: The signal number that was received.
        stack: The interrupted frame, as passed to Python signal handlers
            (may be None, in which case the current stack is formatted).
    """
    logger.critical(f"{signalNum} : SIGSEGV arises")
    stack_text = "".join(traceback.format_stack(stack))
    logger.critical(f"Stack trace:\n{stack_text}")

try:
    faulthandler.enable()
except (AttributeError, OSError, RuntimeError, ValueError) as err:
    # faulthandler needs sys.stderr to expose a real file descriptor; a
    # diagnostics-only feature must never prevent the server from starting.
    logger.warning(f"Could not enable faulthandler: {err}")

from loader import initialize_models
from engine import Engine, base64_data_uri_to_PIL_Image, create_engine

# Global constants
# Root directory for runtime data; overridable via the DATA_ROOT env variable.
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")

# Cache of PIL images keyed by string (presumably an image hash — confirm).
image_cache: Dict[str, Image.Image] = {}

# Registry of live sessions keyed by session id. The shutdown and
# session-cleanup handlers read and mutate it unconditionally, so it must
# exist at import time even though nothing in this file currently adds
# entries; without it those handlers fail with NameError.
active_sessions: Dict[str, Any] = {}

async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """
    Handle WebSocket connections for the FacePoke application.

    Each TEXT frame is parsed as a JSON object. Objects whose ``type`` is
    ``'modify_image'`` are forwarded to ``handle_modify_image`` together with
    their optional ``uuid``. A frame that is not valid JSON, or is not a JSON
    object, is logged and skipped instead of tearing down the connection.
    The loop ends when the peer closes the socket, the connection drops, or
    an ERROR frame arrives.

    Args:
        request (web.Request): The incoming request object.

    Returns:
        web.WebSocketResponse: The WebSocket response object.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    try:
        logger.info("New WebSocket connection established")

        # `async for` stops on CLOSE, CLOSING and CLOSED frames. A bare
        # `ws.receive()` loop missed CLOSED (returned on an abrupt disconnect)
        # and kept polling the dead socket until aiohttp raised RuntimeError.
        async for msg in ws:
            if msg.type == WSMsgType.ERROR:
                logger.warning(f"WebSocket connection closed: {msg.type} ({ws.exception()})")
                break
            if msg.type != WSMsgType.TEXT:
                continue

            try:
                data = json.loads(msg.data)
            except json.JSONDecodeError as e:
                logger.warning(f"Ignoring malformed JSON message: {e}")
                continue
            if not isinstance(data, dict):
                logger.warning("Ignoring WebSocket message that is not a JSON object")
                continue

            # let's not log user requests, they are heavy

            if data.get('type') == 'modify_image':
                # Named request_uuid so the `uuid` module is not shadowed.
                request_uuid = data.get('uuid')
                if not request_uuid:
                    logger.warning("Received message without UUID")

                await handle_modify_image(request, ws, data, request_uuid)

    except Exception as e:
        logger.error(f"Error in websocket_handler: {str(e)}")
        logger.exception("Full traceback:")
    finally:
        logger.info("WebSocket connection closed")
    return ws

async def handle_modify_image(request: web.Request, ws: web.WebSocketResponse, msg: Dict[str, Any], uuid: str):
    """
    Run an image modification through the engine and reply over the socket.

    The source is ``msg['image_hash']`` when that value is truthy, otherwise
    ``msg['image']``; ``msg['params']`` is handed to the engine unchanged.
    On success a ``"modified_image"`` message with the new image and the hash
    reported by ``engine.get_image_hash`` is sent; on any exception an error
    reply with ``success: False`` and the exception text is sent instead.
    ``uuid`` is echoed in both replies so the client can match them to its
    request.

    Args:
        request (web.Request): The incoming request object (provides the engine).
        ws (web.WebSocketResponse): The socket to reply on.
        msg (Dict[str, Any]): The message containing the image or image_hash and modification parameters.
        uuid: A unique identifier for the request, echoed back in the reply.
    """
    logger.info("Received modify_image request")
    try:
        engine = request.app['engine']
        # A truthy image_hash takes precedence over raw image data; 'image'
        # is only looked up when no usable hash was supplied.
        image_or_hash = msg.get('image_hash') or msg['image']

        result_image = await engine.modify_image(image_or_hash, msg['params'])

        reply = {
            "type": "modified_image",
            "image": result_image,
            "image_hash": engine.get_image_hash(image_or_hash),
            "success": True,
            "uuid": uuid,
        }
        await ws.send_json(reply)
        logger.info("Successfully sent modified image")
    except Exception as e:
        logger.error(f"Error in modify_image: {str(e)}")
        failure_reply = {
            "type": "modified_image",
            "success": False,
            "error": str(e),
            "uuid": uuid,  # echoed even on failure
        }
        await ws.send_json(failure_reply)

def _read_public_text(filename: str) -> str:
    """Return the contents of ``public/<filename>`` located next to this module.

    The ``with`` block closes the handle as soon as the file has been read,
    instead of leaving that to the garbage collector. The file is decoded
    explicitly as UTF-8 rather than with the locale-dependent default.
    NOTE(review): assumes the bundled assets are UTF-8 encoded — confirm.
    """
    path = os.path.join(os.path.dirname(__file__), "public", filename)
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

async def index(request: web.Request) -> web.Response:
    """Serve the index.html file"""
    return web.Response(content_type="text/html", text=_read_public_text("index.html"))

async def js_index(request: web.Request) -> web.Response:
    """Serve the index.js file"""
    return web.Response(content_type="application/javascript", text=_read_public_text("index.js"))

async def hf_logo(request: web.Request) -> web.Response:
    """Serve the hf-logo.svg file"""
    return web.Response(content_type="image/svg+xml", text=_read_public_text("hf-logo.svg"))

async def on_shutdown(app: web.Application):
    """Cleanup function to be called on server shutdown.

    Stops every registered session, empties ``active_sessions`` and then
    calls ``cleanup()`` on ``app['engine']`` if one is set. A session whose
    ``stop()`` raises is logged and skipped, so one misbehaving session
    cannot prevent the remaining sessions or the engine from being cleaned up.

    Args:
        app (web.Application): The application being shut down.
    """
    logger.info("Server shutdown initiated, cleaning up resources...")
    for session in list(active_sessions.values()):
        try:
            await session.stop()
        except Exception:
            logger.exception("Error while stopping a session during shutdown")
    active_sessions.clear()
    logger.info("All active sessions have been closed")

    if 'engine' in app:
        await app['engine'].cleanup()
        logger.info("Engine instance cleaned up")

    logger.info("Server shutdown complete")

async def initialize_app() -> web.Application:
    """Initialize and configure the web application.

    Loads the models, builds the Engine from them, stores it under
    ``app['engine']``, registers ``on_shutdown`` and wires the HTTP and
    WebSocket routes.

    Returns:
        web.Application: The configured application.

    Raises:
        Exception: Any error raised during initialization is logged with its
            traceback and re-raised unchanged.
    """
    try:
        logger.info("Initializing application...")
        models = await initialize_models()
        logger.info("🚀 Creating Engine instance...")
        engine = create_engine(models)
        logger.info("✅ Engine instance created.")

        app = web.Application()
        app['engine'] = engine

        app.on_shutdown.append(on_shutdown)

        # Configure routes
        app.router.add_get("/", index)
        app.router.add_get("/index.js", js_index)
        app.router.add_get("/hf-logo.svg", hf_logo)
        app.router.add_get("/ws", websocket_handler)

        logger.info("Application routes configured")

        return app
    except Exception as e:
        logger.error(f"🚨 Error during application initialization: {str(e)}")
        logger.exception("Full traceback:")
        raise

async def start_background_tasks(app: web.Application):
    """
    Launch the periodic session-cleanup loop as a background task.

    The task handle is kept in ``app['cleanup_task']`` so the cleanup hook
    can cancel it when the application shuts down.

    Args:
        app (web.Application): The web application instance.
    """
    cleanup_loop = periodic_cleanup(app)
    app['cleanup_task'] = asyncio.create_task(cleanup_loop)

async def cleanup_background_tasks(app: web.Application):
    """
    Clean up background tasks when the application is shutting down.

    Cancels the task stored in ``app['cleanup_task']`` and waits for it to
    finish. Does nothing if no such task was ever started.

    Args:
        app (web.Application): The web application instance.
    """
    task = app.get('cleanup_task')
    if task is None:
        # start_background_tasks never ran (e.g. startup did not complete).
        return
    task.cancel()
    # periodic_cleanup normally swallows the cancellation and returns, but a
    # task cancelled before its first step re-raises CancelledError when
    # awaited; that outcome is expected here, not an error.
    with contextlib.suppress(asyncio.CancelledError):
        await task

async def periodic_cleanup(app: web.Application):
    """
    Sweep inactive sessions once per hour until the task is cancelled.

    A failing sweep is logged (with traceback) and the loop carries on;
    cancellation ends the loop and the coroutine returns normally.

    Args:
        app (web.Application): The web application instance.
    """
    interval_seconds = 3600  # one hour between sweeps
    while True:
        try:
            await asyncio.sleep(interval_seconds)
            await cleanup_inactive_sessions(app)
        except asyncio.CancelledError:
            return
        except Exception as e:
            logger.error(f"Error in periodic cleanup: {str(e)}")
            logger.exception("Full traceback:")

async def cleanup_inactive_sessions(app: web.Application):
    """
    Stop and deregister every session that is no longer running.

    A session is treated as inactive when ``session.is_running.is_set()`` is
    false. Matching ids are collected first; then each session is popped from
    ``active_sessions`` and its ``stop()`` is awaited.

    Args:
        app (web.Application): The web application instance (not used by the body).
    """
    logger.info("Starting cleanup of inactive sessions")
    # Snapshot the ids before mutating active_sessions inside the loop.
    stale_ids = [
        sid
        for sid, sess in active_sessions.items()
        if not sess.is_running.is_set()
    ]
    for sid in stale_ids:
        stale_session = active_sessions.pop(sid)
        await stale_session.stop()
        logger.info(f"Cleaned up inactive session: {sid}")
    logger.info(f"Cleaned up {len(stale_ids)} inactive sessions")

def main():
    """
    Main function to start the FacePoke application.

    Builds the application and serves it on 0.0.0.0:8080. The application is
    built by a coroutine handed to ``web.run_app``, so initialization runs in
    the same event loop that later serves requests (``asyncio.run`` would
    close its loop before the server started). If startup fails, the error
    is logged and the process exits with status 1 so supervisors notice.
    """
    async def _build_app() -> web.Application:
        # Create the app and attach the background-task lifecycle hooks.
        app = await initialize_app()
        app.on_startup.append(start_background_tasks)
        app.on_cleanup.append(cleanup_background_tasks)
        logger.info("Application initialized, starting web server")
        return app

    try:
        logger.info("Starting FacePoke application")
        web.run_app(_build_app(), host="0.0.0.0", port=8080)
    except Exception as e:
        logger.critical(f"🚨 FATAL: Failed to start the app: {str(e)}")
        logger.exception("Full traceback:")
        sys.exit(1)

if __name__ == "__main__":
    main()