File size: 3,213 Bytes
e679d69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import json
import logging
import time

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request

from lagent.schema import AgentMessage
from lagent.utils import load_class_from_string


class AgentAPIServer:
    """Expose a lagent agent over HTTP using FastAPI and uvicorn.

    The agent is instantiated from ``config``: ``config['type']`` is either a
    class or an import string resolved with ``load_class_from_string``
    (optionally using ``config['python_path']``), and every remaining key is
    forwarded to the agent's constructor.

    Construction blocks: ``__init__`` finishes by calling :meth:`run`, which
    starts uvicorn and does not return until the server stops.

    Routes:
        GET  /health_check         -> liveness probe with a timestamp.
        POST /chat_completion      -> run the agent on the posted messages.
        GET  /memory/{session_id}  -> the agent's ``state_dict`` for a session.
    """

    def __init__(self,
                 config: dict,
                 host: str = '127.0.0.1',
                 port: int = 8090):
        """Build the app and agent, register routes, then serve.

        Args:
            config: Agent configuration. Must contain ``'type'``; may contain
                ``'python_path'``. The caller's dict is not modified.
            host: Interface uvicorn binds to.
            port: TCP port uvicorn listens on.

        Raises:
            KeyError: If ``config`` has no ``'type'`` entry.
        """
        self.app = FastAPI(docs_url='/')
        # NOTE(review): wildcard origins together with allow_credentials let
        # any web page call this server from a browser; only appropriate on
        # trusted networks.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        self.agent = self._build_agent(config)
        self.setup_routes()
        self.run(host, port)

    @staticmethod
    def _build_agent(config: dict):
        """Instantiate the agent described by ``config`` without mutating it.

        Args:
            config: Mapping with a required ``'type'`` (a class or an import
                string), an optional ``'python_path'`` used when resolving a
                string type, and keyword arguments for the agent constructor.

        Returns:
            The constructed agent instance.

        Raises:
            KeyError: If ``'type'`` is missing.
        """
        # Work on a shallow copy: popping from the caller's dict would strip
        # 'type' and make the same config unusable a second time.
        kwargs = dict(config)
        agent_cls = kwargs.pop('type')
        python_path = kwargs.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        return agent_cls(**kwargs)

    @staticmethod
    def _parse_messages(body) -> list:
        """Pop and decode the ``'message'`` entry of a chat request body.

        Strings are passed through unchanged. Any other item is validated into
        an ``AgentMessage``. A single string or object is treated as a
        one-element list, because iterating it directly would split a string
        into characters or a dict into its keys.

        Args:
            body: Decoded JSON request body. ``'message'`` is removed from it
                in place, so the remaining keys can be forwarded to the agent
                as keyword arguments.

        Returns:
            The list of positional messages for the agent call.

        Raises:
            ValueError: If ``body`` is not a dict or lacks ``'message'``. Also
                raised when an item fails ``AgentMessage.model_validate``
                (presumably pydantic, whose ValidationError subclasses
                ValueError).
            TypeError: If ``'message'`` is not iterable (e.g. ``null``).
        """
        if not isinstance(body, dict) or 'message' not in body:
            raise ValueError(
                "request body must be a JSON object with a 'message' field")
        raw = body.pop('message')
        if isinstance(raw, (str, dict)):
            raw = [raw]
        return [
            m if isinstance(m, str) else AgentMessage.model_validate(m)
            for m in raw
        ]

    def setup_routes(self):
        """Register the health-check, chat and memory routes on ``self.app``."""

        def heartbeat():
            return {'status': 'success', 'timestamp': time.time()}

        async def process_message(request: Request):
            # Malformed input is the client's fault: answer 400 instead of
            # masking it as a server error. JSONDecodeError is a ValueError.
            try:
                body = await request.json()
                message = self._parse_messages(body)
            except (ValueError, TypeError) as e:
                logging.warning('Rejected chat_completion request: %s', e)
                raise HTTPException(
                    status_code=400, detail=f'Invalid request: {e}') from e
            # Whatever is left in the body is forwarded as keyword arguments.
            try:
                return await self.agent(*message, **body)
            except Exception as e:
                logging.exception('Error processing message: %s', e)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        def get_memory(session_id: int = 0):
            try:
                return self.agent.state_dict(session_id)
            except KeyError:
                raise HTTPException(
                    status_code=404, detail="Session ID not found") from None
            except Exception as e:
                logging.exception('Error retrieving memory: %s', e)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        self.app.add_api_route('/health_check', heartbeat, methods=['GET'])
        self.app.add_api_route(
            '/chat_completion', process_message, methods=['POST'])
        self.app.add_api_route(
            '/memory/{session_id}', get_memory, methods=['GET'])

    def run(self, host='127.0.0.1', port=8090):
        """Serve ``self.app`` with uvicorn; blocks until the server exits."""
        logging.info(f'Starting server at {host}:{port}')
        uvicorn.run(self.app, host=host, port=port)


def parse_args(argv=None):
    """Parse command-line options for the agent API server.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]`` as usual.

    Returns:
        argparse.Namespace with ``host`` (str), ``port`` (int) and ``config``
        (dict decoded from the ``--config`` JSON string).
    """

    def json_object(text):
        # With plain ``type=json.loads`` argparse reports decode failures as
        # the cryptic "invalid loads value". It would also accept JSON
        # arrays/strings, which the server cannot use: it needs a dict with
        # a 'type' key.
        try:
            value = json.loads(text)
        except json.JSONDecodeError as e:
            raise argparse.ArgumentTypeError(f'invalid JSON: {e}') from e
        if not isinstance(value, dict):
            raise argparse.ArgumentTypeError('config must be a JSON object')
        if 'type' not in value:
            raise argparse.ArgumentTypeError(
                "config must contain a 'type' key")
        return value

    parser = argparse.ArgumentParser(description='Async Agent API Server')
    parser.add_argument('--host', type=str, default='127.0.0.1')
    parser.add_argument('--port', type=int, default=8090)
    parser.add_argument(
        '--config',
        type=json_object,
        required=True,
        help='JSON configuration for the agent')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Root logging must be configured before the server logs its start-up
    # line; the server constructor blocks while serving.
    logging.basicConfig(level=logging.INFO)
    cli = parse_args()
    AgentAPIServer(
        cli.config,
        host=cli.host,
        port=cli.port,
    )