File size: 4,021 Bytes
7781557 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import json
import logging
from typing import Dict

from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from jose import JWTError, jwt
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI

from .auth import SECRET_KEY, ALGORITHM
from .db.database import get_user_by_username
class ConnectionManager:
    """Track one websocket and one LLM handle per authenticated username.

    NOTE(review): connections are keyed by username only. A second login by
    the same user replaces the first socket in ``active_connections``, and
    ``send_message`` then answers on the newest socket, even for a prompt
    that arrived on an older one.
    """

    def __init__(self):
        # username -> the websocket most recently registered for that user.
        self.active_connections: Dict[str, WebSocket] = {}
        # One client built per manager and shared by every user.
        self.llm = ChatOpenAI(model="gpt-4o-mini")
        # username -> runnable used to answer that user's prompts. It is
        # currently always the shared ``self.llm``, and only the latest
        # message is passed to it, so no per-user history is kept.
        self.chains = {}

    async def connect(self, websocket: WebSocket, username: str):
        """Register ``websocket`` for ``username`` and send a confirmation frame.

        The caller must already have accepted ``websocket``; this method does
        not call ``accept()``.
        """
        self.active_connections[username] = websocket
        self.chains[username] = self.llm
        await websocket.send_json({
            "type": "connection_established",
            "message": f"Connected as {username}",
        })

    def disconnect(self, username: str, websocket=None):
        """Forget the connection and chain registered for ``username``.

        Args:
            username: user whose entries should be removed. Unknown usernames
                are ignored.
            websocket: optional socket that is going away. When given,
                nothing is removed unless it is the socket currently
                registered for ``username``. This stops a stale connection
                closing from evicting a newer login by the same user. When
                omitted, removal is unconditional, as before.
        """
        if websocket is not None and self.active_connections.get(username) is not websocket:
            return
        self.active_connections.pop(username, None)
        # Pop instead of assigning None so ``chains`` does not keep one dead
        # entry per username ever seen.
        self.chains.pop(username, None)

    async def send_message(self, message: str, username: str):
        """Stream the LLM's reply to ``message`` back to ``username``'s socket.

        Each streamed ``AIMessage`` chunk is sent as its own
        ``{"type": "message", "message": ..., "sender": "ai"}`` frame.
        Failures while producing the reply are reported to the client as an
        ``{"type": "error", ...}`` frame. ``WebSocketDisconnect`` is
        re-raised, because a gone client cannot receive an error frame.
        Does nothing if ``username`` has no registered connection.
        """
        websocket = self.active_connections.get(username)
        if websocket is None:
            return
        try:
            async for chunk in self.chains[username].astream(message):
                if isinstance(chunk, AIMessage):
                    await websocket.send_json({
                        "type": "message",
                        "message": chunk.content,
                        "sender": "ai",
                    })
        except WebSocketDisconnect:
            # Sending an error frame over a closed socket would fail as well;
            # let the caller handle the disconnect.
            raise
        except Exception as e:
            # NOTE(review): the exception text goes to the client verbatim;
            # consider a generic message if it can carry internal details.
            await websocket.send_json({
                "type": "error",
                "message": str(e),
            })
# Module-level ConnectionManager shared by every call to handle_websocket.
# NOTE(review): this is created at import time, so ConnectionManager.__init__
# (which constructs the ChatOpenAI client) runs when this module is imported.
manager = ConnectionManager()
def _extract_token(auth_message):
    """Return the JWT carried by the first client frame, or None if absent.

    The frame may be a JSON object with a ``"token"`` field or, if it is not
    valid JSON, the raw token text itself. Other shapes yield None so the
    caller can reject them as an authentication failure instead of crashing.
    Those shapes are: a JSON array, number or string, or an object whose
    ``"token"`` is missing or not a non-empty string.
    """
    try:
        data = json.loads(auth_message)
    except json.JSONDecodeError:
        # Not JSON: treat the whole frame as the raw token.
        return auth_message or None
    if not isinstance(data, dict):
        return None
    token = data.get("token")
    return token if isinstance(token, str) and token else None


def _extract_chat_text(message):
    """Return the prompt to forward for one client frame, or None to ignore it.

    A JSON object is a chat message only when ``"type" == "message"``. In
    that case its ``"content"`` (default ``""``) is forwarded; a ``null``
    content counts as nothing to send. Other JSON objects are ignored.
    Everything else is forwarded verbatim as plain text: non-JSON text, and
    JSON that is not an object, such as ``42`` or ``[1, 2]``.
    """
    try:
        data = json.loads(message)
    except json.JSONDecodeError:
        return message
    if not isinstance(data, dict):
        # Calling .get() on a non-dict would raise AttributeError and tear
        # down the whole connection; treat the frame as plain text instead.
        return message
    if data.get("type") != "message":
        return None
    return data.get("content", "")


async def handle_websocket(websocket: WebSocket):
    """Authenticate one chat websocket, then relay its messages to the LLM.

    The first text frame must carry a JWT (see ``_extract_token``).
    - A missing or invalid token gets an ``"Authentication failed"`` error
      frame and close code 1008 (RFC 6455 policy violation).
    - A valid token whose ``sub`` claim is empty, or names no user found by
      ``get_user_by_username``, is closed with 1008 without a frame.
    - Once authenticated, each later frame is handed to
      ``manager.send_message`` until the client disconnects.
    - Any other exception is logged and the socket is closed with 1011
      (internal error).
    """
    await websocket.accept()  # the only accept(); manager.connect() does not accept
    username = None
    try:
        token = _extract_token(await websocket.receive_text())
        payload = None
        if token is not None:
            try:
                payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            except JWTError:
                pass
        if payload is None:
            await websocket.send_json({
                "type": "error",
                "message": "Authentication failed",
            })
            await websocket.close(code=1008)
            return

        subject = payload.get("sub")
        if not subject or not await get_user_by_username(subject):
            await websocket.close(code=1008)
            return

        # Only set once the user is verified, so the cleanup in `finally`
        # never touches another session's registration.
        username = subject
        await manager.connect(websocket, username)

        while True:
            text = _extract_chat_text(await websocket.receive_text())
            if text is not None:
                await manager.send_message(text, username)
    except WebSocketDisconnect:
        pass  # normal client-initiated close; cleanup happens in `finally`
    except Exception as e:
        logging.getLogger(__name__).exception("WebSocket error: %s", e)
        try:
            await websocket.close(code=1011)
        except Exception:
            # Best effort: the socket may already be closed.
            pass
    finally:
        # Unregister only if this socket is still the one registered for the
        # user; a newer login under the same username must not be evicted.
        if username and manager.active_connections.get(username) is websocket:
            manager.disconnect(username)