Spaces:
Sleeping
Sleeping
import json | |
import os | |
import subprocess | |
import sys | |
import time | |
import aiohttp | |
import requests | |
from lagent.schema import AgentMessage | |
class HTTPAgentClient:
    """Synchronous client for an agent exposed over HTTP.

    Args:
        host: Host name or IP address of the agent server.
        port: TCP port of the agent server.
        timeout: Timeout (seconds) passed to every ``requests`` call;
            ``None`` means requests wait indefinitely.
    """

    def __init__(self, host='127.0.0.1', port=8090, timeout=None):
        self.host = host
        self.port = port
        self.timeout = timeout

    def is_alive(self):
        """Return ``True`` if ``GET /health_check`` answers with HTTP 200.

        Network-level failures (connection refused, timeout, invalid URL,
        ...) are reported as ``False`` instead of being raised.
        """
        try:
            resp = requests.get(
                f'http://{self.host}:{self.port}/health_check',
                timeout=self.timeout)
        except requests.RequestException:
            # Only network/HTTP-library errors mean "not alive"; anything
            # else (e.g. KeyboardInterrupt, programming errors) propagates.
            return False
        return resp.status_code == 200

    def __call__(self, *message, session_id: int = 0, **kwargs):
        """POST ``message`` to ``/chat_completion`` and return the reply.

        Args:
            *message: Items to send. Strings are sent as-is; any other item
                is serialized with its ``model_dump()`` method.
            session_id: Session identifier included in the request payload.
            **kwargs: Extra fields merged into the JSON payload (they are
                applied last, so they can override ``message``/``session_id``).

        Returns:
            An ``AgentMessage`` validated from the response body when the
            server answers with HTTP 200; otherwise the decoded JSON body.

        Raises:
            requests.RequestException: On network failure or timeout.
            ValueError: If the response body is not valid JSON (raised by
                ``response.json()``).
        """
        payload = {
            'message': [
                m if isinstance(m, str) else m.model_dump() for m in message
            ],
            'session_id': session_id,
            **kwargs,
        }
        response = requests.post(
            f'http://{self.host}:{self.port}/chat_completion',
            json=payload,
            headers={'Content-Type': 'application/json'},
            timeout=self.timeout)
        resp = response.json()
        if response.status_code != 200:
            return resp
        return AgentMessage.model_validate(resp)

    def state_dict(self, session_id: int = 0):
        """Return the decoded JSON from ``GET /memory/{session_id}``.

        The HTTP status code is not checked; whatever JSON body the server
        returns is passed back to the caller.
        """
        resp = requests.get(
            f'http://{self.host}:{self.port}/memory/{session_id}',
            timeout=self.timeout)
        return resp.json()
class HTTPAgentServer(HTTPAgentClient):
    """Launch the agent HTTP server in a subprocess and act as its client.

    The constructor runs ``lagent/distributed/http_serve/app.py`` with the
    current Python interpreter and blocks until the server reports that it
    is listening.

    Args:
        gpu_id: Value exported as ``CUDA_VISIBLE_DEVICES`` to the subprocess
            (an int such as ``0`` or a string such as ``'0,1'``).
        config: JSON-serializable agent configuration, passed to the server
            via ``--config``.
        host: Host the server binds to and the client connects to.
        port: Port the server binds to and the client connects to.
        timeout: Request timeout forwarded to ``HTTPAgentClient``.

    Raises:
        RuntimeError: If the server process exits before it reports ready.
    """

    def __init__(self, gpu_id, config, host='127.0.0.1', port=8090,
                 timeout=None):
        super().__init__(host, port, timeout)
        self.gpu_id = gpu_id
        self.config = config
        self.start_server()

    def start_server(self):
        """Spawn the server subprocess and block until it is ready.

        The server's combined stdout/stderr is echoed to this process's
        stdout. After the readiness line appears, a daemon thread keeps
        forwarding output so the pipe buffer never fills up and stalls the
        server.

        Raises:
            RuntimeError: If the server closes its output (i.e. exits)
                before printing the readiness line.
        """
        import threading

        env = os.environ.copy()
        # Environment values must be str; accept ints like gpu_id=0 too.
        env['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id)
        # NOTE(review): the script path is relative to the current working
        # directory, so this presumably must be launched from the repo root.
        cmds = [
            sys.executable, 'lagent/distributed/http_serve/app.py',
            '--host', self.host,
            '--port', str(self.port),
            '--config', json.dumps(self.config),
        ]
        self.process = subprocess.Popen(
            cmds,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True)

        # readline() blocks until a full line (or EOF -> '') is available.
        for line in iter(self.process.stdout.readline, ''):
            sys.stdout.write(line)
            sys.stdout.flush()
            # Readiness marker from the uvicorn log; adjust if it changes.
            if 'Uvicorn running on' in line:
                break
        else:
            # EOF before the readiness marker: the server has shut down.
            returncode = self.process.wait()
            raise RuntimeError(
                f'Agent server exited with code {returncode} before it '
                f'started listening on {self.host}:{self.port}')

        # Keep draining the pipe; otherwise the server blocks once the OS
        # pipe buffer is full and stops serving requests.
        threading.Thread(target=self._forward_output, daemon=True).start()

    def _forward_output(self):
        """Echo the server's remaining output to stdout until EOF."""
        for line in iter(self.process.stdout.readline, ''):
            sys.stdout.write(line)
            sys.stdout.flush()

    def shutdown(self, timeout=None):
        """Terminate the server subprocess and wait for it to exit.

        Args:
            timeout: Seconds to wait after SIGTERM before force-killing the
                process. ``None`` (default) waits indefinitely.
        """
        self.process.terminate()
        try:
            self.process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            self.process.kill()
            self.process.wait()
class AsyncHTTPAgentMixin:
    """Mixin that turns an HTTP agent client's ``__call__`` into a coroutine.

    It is placed before an ``HTTPAgentClient``-derived base so that its
    ``__call__`` wins in the MRO. It reads ``host``, ``port`` and ``timeout``
    from the host class.
    """

    async def __call__(self, *message, session_id: int = 0, **kwargs):
        """Asynchronously POST ``message`` to ``/chat_completion``.

        Strings are sent verbatim; other items are serialized with
        ``model_dump()``. Extra keyword arguments are merged into the JSON
        payload last. Returns an ``AgentMessage`` built from the body on
        HTTP 200, otherwise the decoded JSON body itself.
        """
        endpoint = f'http://{self.host}:{self.port}/chat_completion'
        serialized = [
            item if isinstance(item, str) else item.model_dump()
            for item in message
        ]
        payload = {'message': serialized, 'session_id': session_id, **kwargs}
        client_timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(
                    endpoint,
                    json=payload,
                    headers={'Content-Type': 'application/json'},
            ) as reply:
                body = await reply.json()
                if reply.status == 200:
                    return AgentMessage.model_validate(body)
                return body
class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient):
    """HTTP agent client whose ``__call__`` is an awaitable coroutine.

    Connection settings, ``is_alive`` and ``state_dict`` are inherited from
    ``HTTPAgentClient``; ``__call__`` comes from ``AsyncHTTPAgentMixin``
    because the mixin is listed first among the bases.
    """
class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer):
    """Subprocess-backed agent server whose ``__call__`` is a coroutine.

    Construction, server startup and ``shutdown`` are inherited from
    ``HTTPAgentServer``; ``__call__`` comes from ``AsyncHTTPAgentMixin``
    because the mixin is listed first among the bases.
    """