Spaces:
Sleeping
Sleeping
File size: 3,213 Bytes
e679d69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import argparse
import json
import logging
import time
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from lagent.schema import AgentMessage
from lagent.utils import load_class_from_string
class AgentAPIServer:
    """Expose an async lagent agent over HTTP using FastAPI.

    The agent is built from ``config`` during construction, the routes are
    registered, and uvicorn is started immediately; ``uvicorn.run`` blocks,
    so the constructor only returns once the server stops.

    Routes:
        GET  /health_check        -> ``{'status': 'success', 'timestamp': ...}``
        POST /chat_completion     -> result of ``await agent(*messages, **kwargs)``
        GET  /memory/{session_id} -> ``agent.state_dict(session_id)``
    Interactive API docs are served at ``/``.
    """

    def __init__(self,
                 config: dict,
                 host: str = '127.0.0.1',
                 port: int = 8090):
        """Build the agent, register routes and start serving.

        Args:
            config: Agent configuration. ``'type'`` is either the agent class
                itself or an import string resolved with
                ``load_class_from_string`` (optionally using
                ``'python_path'``); every other key is passed to the class
                constructor as a keyword argument. The dict is not modified.
            host: Interface to bind.
            port: TCP port to bind.

        Raises:
            KeyError: If ``config`` has no ``'type'`` key.
        """
        self.app = FastAPI(docs_url='/')
        # Permissive CORS: any origin may call the API from a browser.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        # Work on a shallow copy so the pops below do not mutate the
        # caller's dict.
        agent_kwargs = dict(config)
        agent_cls = agent_kwargs.pop('type')
        python_path = agent_kwargs.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        self.agent = agent_cls(**agent_kwargs)
        self.setup_routes()
        self.run(host, port)

    @staticmethod
    def _parse_chat_body(body):
        """Split a ``/chat_completion`` JSON body into messages and kwargs.

        Args:
            body: Decoded JSON request body. Must be an object with a
                ``'message'`` entry holding a list of messages; a single
                string or single message object is treated as a one-element
                list. String items are kept as-is, anything else is
                validated with ``AgentMessage.model_validate``.

        Returns:
            ``(messages, kwargs)`` where ``kwargs`` holds every other key of
            ``body``. ``body`` itself is left unchanged.

        Raises:
            TypeError: If ``body`` is not a dict or ``'message'`` is not
                iterable.
            KeyError: If ``'message'`` is missing.
            ValueError: If a message fails ``AgentMessage`` validation
                (pydantic's ``ValidationError`` subclasses ``ValueError``).
        """
        if not isinstance(body, dict):
            raise TypeError(
                f'request body must be a JSON object, got {type(body).__name__}')
        kwargs = dict(body)
        raw_messages = kwargs.pop('message')
        # Iterating a bare string would split it into characters, so wrap a
        # single message in a list first.
        if isinstance(raw_messages, (str, dict)):
            raw_messages = [raw_messages]
        messages = [
            m if isinstance(m, str) else AgentMessage.model_validate(m)
            for m in raw_messages
        ]
        return messages, kwargs

    def setup_routes(self):
        """Register the HTTP endpoints on ``self.app``."""

        def heartbeat():
            """Liveness probe: report success and the current server time."""
            return {'status': 'success', 'timestamp': time.time()}

        async def process_message(request: Request):
            """Run the agent on the posted messages and return its result.

            A malformed request body yields HTTP 400; any exception raised by
            the agent yields HTTP 500 and is logged with its traceback.
            """
            try:
                body = await request.json()
                messages, kwargs = self._parse_chat_body(body)
            except (ValueError, KeyError, TypeError) as e:
                # Client-side problem (bad JSON, missing 'message', invalid
                # message schema): report it instead of masking it as a 500.
                logging.warning('Rejected malformed chat request: %s', e)
                raise HTTPException(
                    status_code=400,
                    detail=f'Invalid request body: {e}') from e
            try:
                return await self.agent(*messages, **kwargs)
            except Exception as e:
                logging.exception('Error processing message: %s', e)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        def get_memory(session_id: int = 0):
            """Return ``agent.state_dict(session_id)``.

            A ``KeyError`` from ``state_dict`` maps to HTTP 404; any other
            error maps to HTTP 500 and is logged with its traceback.
            """
            try:
                return self.agent.state_dict(session_id)
            except KeyError as e:
                raise HTTPException(
                    status_code=404, detail="Session ID not found") from e
            except Exception as e:
                logging.exception(
                    'Error retrieving memory for session %s: %s',
                    session_id, e)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        self.app.add_api_route('/health_check', heartbeat, methods=['GET'])
        self.app.add_api_route(
            '/chat_completion', process_message, methods=['POST'])
        self.app.add_api_route(
            '/memory/{session_id}', get_memory, methods=['GET'])

    def run(self, host='127.0.0.1', port=8090):
        """Serve ``self.app`` with uvicorn; blocks until the server stops."""
        logging.info('Starting server at %s:%s', host, port)
        uvicorn.run(self.app, host=host, port=port)
def parse_args():
parser = argparse.ArgumentParser(description='Async Agent API Server')
parser.add_argument('--host', type=str, default='127.0.0.1')
parser.add_argument('--port', type=int, default=8090)
parser.add_argument(
'--config',
type=json.loads,
required=True,
help='JSON configuration for the agent')
args = parser.parse_args()
return args
def main():
    """Script entry point: set up INFO logging, read CLI flags, run the server."""
    logging.basicConfig(level=logging.INFO)
    cli_args = parse_args()
    # Constructing the server starts uvicorn and blocks until shutdown.
    AgentAPIServer(cli_args.config, host=cli_args.host, port=cli_args.port)


if __name__ == '__main__':
    main()
|