File size: 4,021 Bytes
7781557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import logging
from typing import Dict, Optional

from fastapi import HTTPException, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI

from .auth import ALGORITHM, SECRET_KEY
from .db.database import get_user_by_username

class ConnectionManager:
    """Track one WebSocket per username and stream LLM replies to it.

    ``active_connections`` maps a username to its registered WebSocket and
    ``chains`` maps the same username to the runnable used to answer it
    (currently the single shared ``ChatOpenAI`` instance). Registering a
    username that is already present replaces the previously stored socket.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        # One model instance, shared by every connected user.
        self.llm = ChatOpenAI(model="gpt-4o-mini")
        self.chains = {}

    async def connect(self, websocket: WebSocket, username: str):
        """Register an already-accepted ``websocket`` for ``username`` and confirm it.

        The caller is responsible for ``websocket.accept()``; this method does
        not accept the socket itself. After registering, it sends a
        ``connection_established`` frame to the client.
        """
        self.active_connections[username] = websocket
        self.chains[username] = self.llm
        await websocket.send_json({
            "type": "connection_established",
            "message": f"Connected as {username}"
        })

    def disconnect(self, username: str):
        """Forget ``username``'s socket and chain; unknown usernames are ignored."""
        self.active_connections.pop(username, None)
        # Remove the key rather than storing None, so entries do not pile up
        # for every username that has ever connected.
        self.chains.pop(username, None)

    async def send_message(self, message: str, username: str):
        """Stream the LLM's reply to ``message`` back to ``username``'s socket.

        Each streamed chunk that is an ``AIMessage`` is sent as
        ``{"type": "message", "message": <content>, "sender": "ai"}``. If
        anything in the streaming loop raises, the exception text is sent as
        ``{"type": "error", "message": ...}``. Usernames without a registered
        socket are silently ignored.
        """
        websocket = self.active_connections.get(username)
        if websocket is None:
            return
        try:
            chain = self.chains[username]
            async for chunk in chain.astream(message):
                if isinstance(chunk, AIMessage):
                    await websocket.send_json({
                        "type": "message",
                        "message": chunk.content,
                        "sender": "ai"
                    })
        except Exception as e:
            # NOTE(review): str(e) goes to the client verbatim; confirm that
            # upstream (LLM/API) error messages are safe to expose.
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })

# Module-level ConnectionManager used by handle_websocket below. It is created
# at import time, so the ChatOpenAI client in ConnectionManager.__init__ is
# built when this module is first imported.
manager = ConnectionManager()

def _json_object_or_none(text: str) -> Optional[dict]:
    """Return ``text`` parsed as a JSON object, or None if it is not one.

    Valid JSON that is not an object (``42``, ``true``, ``"hi"``, ``[1]``)
    also yields None, so callers fall back to treating the frame as plain
    text instead of calling ``.get`` on a non-dict.
    """
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, dict) else None


async def _reject(websocket: WebSocket) -> None:
    """Send an authentication-failure frame and close with 1008 (policy violation)."""
    await websocket.send_json({
        "type": "error",
        "message": "Authentication failed"
    })
    await websocket.close(code=1008)


async def _authenticate(websocket: WebSocket) -> Optional[str]:
    """Read the first frame and return the authenticated username, or None.

    The frame is either a JSON object ``{"token": "<jwt>"}`` or the raw JWT.
    When this returns None the client has already been rejected via
    ``_reject`` (error frame + close 1008).
    """
    auth_message = await websocket.receive_text()
    data = _json_object_or_none(auth_message)
    token = auth_message if data is None else data.get("token")
    # A missing or non-string token would otherwise reach jwt.decode and fail
    # with a non-JWTError exception instead of a clean auth rejection.
    if not isinstance(token, str) or not token:
        await _reject(websocket)
        return None

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        await _reject(websocket)
        return None

    username = payload.get("sub")
    if not username:
        await _reject(websocket)
        return None

    user = await get_user_by_username(username)
    if not user:
        await _reject(websocket)
        return None
    return username


async def handle_websocket(websocket: WebSocket):
    """Authenticate a WebSocket client, then relay its messages to the LLM.

    Protocol:
      1. The first frame carries a JWT, as raw text or as ``{"token": ...}``.
         It is decoded with ``SECRET_KEY``/``ALGORITHM``, and its ``sub`` claim
         must name a user known to ``get_user_by_username``. Any failure sends
         an ``{"type": "error", ...}`` frame and closes with code 1008.
      2. Each later frame is either a JSON object
         ``{"type": "message", "content": ...}`` or plain text. Both are
         forwarded to ``manager.send_message``. JSON objects with any other
         ``type`` are ignored.

    The user is unregistered from ``manager`` when the loop ends for any
    reason. Unexpected errors are logged and the socket is closed with 1011.
    """
    await websocket.accept()  # accepted once here; manager.connect does not accept
    # Only set once this socket is being registered with the manager, so a
    # failure during authentication never evicts another live session of the
    # same user.
    registered_as = None

    try:
        username = await _authenticate(websocket)
        if username is None:
            return
        # Record the name before connect(): connect() stores the socket first
        # and then sends a frame, so a failed send must still be cleaned up.
        registered_as = username
        await manager.connect(websocket, username)

        while True:
            message = await websocket.receive_text()
            data = _json_object_or_none(message)
            if data is None:
                # Plain text, or JSON that is not an object: forward verbatim.
                await manager.send_message(message, username)
            elif data.get("type") == "message":
                await manager.send_message(data.get("content", ""), username)

    except WebSocketDisconnect:
        pass  # normal client hang-up; cleanup happens in `finally`
    except Exception:
        logging.getLogger(__name__).exception("WebSocket error")
        try:
            await websocket.close(code=1011)
        except Exception:
            # The socket may already be closed; nothing more can be done.
            pass
    finally:
        if registered_as:
            manager.disconnect(registered_as)