Spaces:
Sleeping
Sleeping
File size: 3,745 Bytes
e679d69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import json
import os
import subprocess
import sys
import threading
import time

import aiohttp
import requests

from lagent.schema import AgentMessage
class HTTPAgentClient:
    """Synchronous client for an agent exposed over HTTP.

    Args:
        host: Server host name or IP address.
        port: Server port.
        timeout: Timeout forwarded to every ``requests`` call; ``None``
            means no timeout.
    """

    def __init__(self, host='127.0.0.1', port=8090, timeout=None):
        self.host = host
        self.port = port
        self.timeout = timeout

    def _url(self, path):
        """Return the absolute ``http://host:port/<path>`` URL."""
        return f'http://{self.host}:{self.port}/{path}'

    @property
    def is_alive(self):
        """Whether ``GET /health_check`` answers with HTTP 200.

        Network-level failures (connection refused, timeout, invalid URL,
        ...) are reported as ``False`` instead of being raised. Other
        exceptions, including ``KeyboardInterrupt``, propagate normally.
        """
        try:
            resp = requests.get(self._url('health_check'),
                                timeout=self.timeout)
        except requests.RequestException:
            return False
        return resp.status_code == 200

    def __call__(self, *message, session_id: int = 0, **kwargs):
        """Send messages to ``POST /chat_completion``.

        Args:
            *message: Plain strings, or objects exposing ``model_dump()``
                (e.g. ``AgentMessage``), which are serialized with it.
            session_id: Session identifier included in the JSON payload.
            **kwargs: Extra fields merged into the JSON payload.

        Returns:
            ``AgentMessage`` validated from the body on HTTP 200; otherwise
            the decoded JSON body of the error response, unchanged.
        """
        payload = {
            'message': [
                m if isinstance(m, str) else m.model_dump() for m in message
            ],
            'session_id': session_id,
            **kwargs,
        }
        response = requests.post(
            self._url('chat_completion'),
            json=payload,
            headers={'Content-Type': 'application/json'},
            timeout=self.timeout)
        resp = response.json()
        if response.status_code != 200:
            return resp
        return AgentMessage.model_validate(resp)

    def state_dict(self, session_id: int = 0):
        """Return the decoded JSON body of ``GET /memory/{session_id}``."""
        resp = requests.get(self._url(f'memory/{session_id}'),
                            timeout=self.timeout)
        return resp.json()
class HTTPAgentServer(HTTPAgentClient):
    """Launch the agent HTTP app in a subprocess and act as its client.

    Construction spawns the server script with ``CUDA_VISIBLE_DEVICES`` set
    to ``gpu_id``. It then blocks until the child prints Uvicorn's startup
    banner or closes its output.

    Args:
        gpu_id: Value for ``CUDA_VISIBLE_DEVICES`` in the child process;
            converted with ``str`` so an int such as ``0`` is accepted.
        config: JSON-serializable config passed to the script via
            ``--config``.
        host: Host passed to the server and used by the client.
        port: Port passed to the server and used by the client.
        timeout: Request timeout forwarded to ``HTTPAgentClient``.
    """

    def __init__(self, gpu_id, config, host='127.0.0.1', port=8090,
                 timeout=None):
        super().__init__(host, port, timeout)
        self.gpu_id = gpu_id
        self.config = config
        self.start_server()

    def start_server(self):
        """Spawn the server process and wait until it reports readiness.

        Child stdout/stderr are echoed to this process's stdout. After the
        readiness check, a daemon thread keeps draining the pipe. Without
        it, the child would block once its log output filled the OS pipe
        buffer, because nothing else reads from the pipe.
        """
        env = os.environ.copy()
        # Restrict the child to the requested GPU(s). Environment values
        # must be str, or Popen raises TypeError.
        env['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id)
        # NOTE: the script path is relative, so it is resolved against the
        # caller's current working directory (no ``cwd=`` is passed).
        cmds = [
            sys.executable, 'lagent/distributed/http_serve/app.py',
            '--host', self.host,
            '--port', str(self.port),
            '--config', json.dumps(self.config),
        ]
        self.process = subprocess.Popen(
            cmds,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True)
        stream = self.process.stdout
        # readline() blocks until a line arrives and returns '' at EOF.
        for line in iter(stream.readline, ''):
            self._echo(line)
            # Adjust this marker if the server's startup banner differs.
            if 'Uvicorn running on' in line:
                break
        threading.Thread(
            target=self._drain_output, args=(stream,), daemon=True).start()

    @staticmethod
    def _echo(line):
        """Write one line of child output to stdout immediately."""
        sys.stdout.write(line)
        sys.stdout.flush()

    @classmethod
    def _drain_output(cls, stream):
        """Forward remaining child output to stdout until EOF."""
        for line in iter(stream.readline, ''):
            cls._echo(line)

    def shutdown(self, timeout=None):
        """Terminate the server process and wait for it to exit.

        Args:
            timeout: Seconds to wait after ``terminate()`` before falling
                back to ``kill()``. ``None`` (default) waits indefinitely.
        """
        self.process.terminate()
        try:
            self.process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            self.process.kill()
            self.process.wait()
class AsyncHTTPAgentMixin:
    """Provide an ``aiohttp``-based, awaitable ``__call__``.

    Meant to be listed before an ``HTTPAgentClient``-style base, which must
    supply the ``host``, ``port`` and ``timeout`` attributes read here.
    """

    async def __call__(self, *message, session_id: int = 0, **kwargs):
        """Send messages to ``POST /chat_completion`` asynchronously.

        Args:
            *message: Plain strings, or objects exposing ``model_dump()``
                (e.g. ``AgentMessage``), which are serialized with it.
            session_id: Session identifier included in the JSON payload.
            **kwargs: Extra fields merged into the JSON payload.

        Returns:
            ``AgentMessage`` validated from the body on HTTP 200; otherwise
            the decoded JSON body of the error response, unchanged.
        """
        payload = {
            'message': [
                m if isinstance(m, str) else m.model_dump() for m in message
            ],
            'session_id': session_id,
            **kwargs,
        }
        url = f'http://{self.host}:{self.port}/chat_completion'
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(self.timeout)) as session:
            async with session.post(
                    url,
                    json=payload,
                    headers={'Content-Type': 'application/json'},
            ) as response:
                # content_type=None disables aiohttp's Content-Type check,
                # so a JSON body is parsed regardless of its header. This
                # matches requests' Response.json() in the sync client.
                resp = await response.json(content_type=None)
                if response.status != 200:
                    return resp
                return AgentMessage.model_validate(resp)
class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient):
    """Awaitable variant of ``HTTPAgentClient``.

    ``__call__`` comes from ``AsyncHTTPAgentMixin`` (earlier in the MRO) and
    must be awaited. Everything else is inherited from ``HTTPAgentClient``.
    """
class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer):
    """Awaitable variant of ``HTTPAgentServer``.

    The server process is launched exactly as in ``HTTPAgentServer``.
    ``__call__`` is taken from ``AsyncHTTPAgentMixin`` and must be awaited.
    """
|