File size: 3,745 Bytes
e679d69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import json
import os
import subprocess
import sys
import threading
import time

import aiohttp
import requests

from lagent.schema import AgentMessage


class HTTPAgentClient:
    """Synchronous HTTP client for an agent served over HTTP.

    Talks to the ``/health_check``, ``/chat_completion`` and
    ``/memory/{session_id}`` endpoints of the server at ``host:port``.

    Args:
        host (str): Server host name or IP address.
        port (int): Server port.
        timeout: Timeout forwarded to every ``requests`` call; ``None``
            waits indefinitely.
    """

    def __init__(self, host='127.0.0.1', port=8090, timeout=None):
        self.host = host
        self.port = port
        self.timeout = timeout

    @property
    def is_alive(self):
        """bool: Whether ``/health_check`` answers with HTTP 200.

        Transport failures (e.g. refused connection, timeout) are reported
        as ``False`` instead of being raised.
        """
        try:
            resp = requests.get(
                f'http://{self.host}:{self.port}/health_check',
                timeout=self.timeout)
        except requests.RequestException:
            # Only network/transport errors mean "not alive"; a bare except
            # would also swallow KeyboardInterrupt/SystemExit and real bugs.
            return False
        return resp.status_code == 200

    def __call__(self, *message, session_id: int = 0, **kwargs):
        """Send messages to ``/chat_completion`` and return the reply.

        Args:
            *message: Messages to send; ``str`` items are sent as-is, any
                other item is serialized with ``model_dump()``.
            session_id (int): Session to route the request to.
            **kwargs: Extra fields merged into the JSON payload.

        Returns:
            AgentMessage | Any: The validated reply on HTTP 200, otherwise
            the decoded JSON error body unchanged.
        """
        response = requests.post(
            f'http://{self.host}:{self.port}/chat_completion',
            json={
                'message': [
                    m if isinstance(m, str) else m.model_dump()
                    for m in message
                ],
                'session_id': session_id,
                **kwargs,
            },
            headers={'Content-Type': 'application/json'},
            timeout=self.timeout)
        resp = response.json()
        if response.status_code != 200:
            return resp
        return AgentMessage.model_validate(resp)

    def state_dict(self, session_id: int = 0):
        """Fetch ``/memory/{session_id}`` and return its decoded JSON body.

        Args:
            session_id (int): Session whose memory is requested.
        """
        resp = requests.get(
            f'http://{self.host}:{self.port}/memory/{session_id}',
            timeout=self.timeout)
        return resp.json()


class HTTPAgentServer(HTTPAgentClient):
    """Run an agent HTTP server in a subprocess and act as its client.

    The server script is launched from ``__init__``, which blocks until the
    child's combined stdout/stderr contains the Uvicorn startup banner.
    Requests are then made through the inherited ``HTTPAgentClient`` methods.

    Args:
        gpu_id: Exported to the child as ``CUDA_VISIBLE_DEVICES``; converted
            with ``str()`` so integer ids such as ``0`` are accepted.
        config: JSON-serializable agent config, passed as ``--config``.
        host (str): Address passed to the server and used by the client.
        port (int): Port passed to the server and used by the client.
    """

    def __init__(self, gpu_id, config, host='127.0.0.1', port=8090):
        super().__init__(host, port)
        self.gpu_id = gpu_id
        self.config = config
        self.start_server()

    def start_server(self):
        """Spawn the server process and wait until it reports it is running.

        Child output is echoed to this process's stdout while waiting and,
        afterwards, from a daemon thread, so the pipe is always drained.

        Raises:
            RuntimeError: If the child's output reaches EOF (typically
                because it exited) before the startup banner appears.
        """
        # Environment values must be strings; str() also accepts int ids.
        env = os.environ.copy()
        env['CUDA_VISIBLE_DEVICES'] = str(self.gpu_id)
        # NOTE(review): the script path is relative to the current working
        # directory; confirm callers always launch from the repository root.
        cmds = [
            sys.executable, 'lagent/distributed/http_serve/app.py', '--host',
            self.host, '--port',
            str(self.port), '--config',
            json.dumps(self.config)
        ]
        self.process = subprocess.Popen(
            cmds,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True)
        stream = self.process.stdout

        # readline() blocks until a whole line (or EOF, returned as '') is
        # available, so no polling sleep is needed between reads.
        for line in iter(stream.readline, ''):
            self._echo(line)
            if 'Uvicorn running on' in line:  # adjust to the actual banner
                break
        else:
            # EOF before the banner: the server never came up.
            raise RuntimeError(
                'Agent server exited before becoming ready '
                f'(return code: {self.process.poll()}).')

        # Keep draining the pipe. If nobody reads it, the OS buffer fills up
        # and the child blocks on its next write, stalling the server.
        threading.Thread(
            target=self._drain_output, args=(stream,), daemon=True).start()

    def _drain_output(self, stream):
        """Echo every remaining line of ``stream`` until it reaches EOF."""
        for line in iter(stream.readline, ''):
            self._echo(line)

    @staticmethod
    def _echo(line):
        """Write one line of child output to this process's stdout."""
        sys.stdout.write(line)
        sys.stdout.flush()

    def shutdown(self):
        """Terminate the server process (``Popen.terminate``) and wait for it."""
        self.process.terminate()
        self.process.wait()


class AsyncHTTPAgentMixin:
    """Mixin that replaces the blocking ``__call__`` with a coroutine.

    Expects the host class to provide ``host``, ``port`` and ``timeout``
    attributes (as ``HTTPAgentClient`` does). When listed before the
    synchronous base class, its ``__call__`` takes precedence in the MRO.
    """

    async def __call__(self, *message, session_id: int = 0, **kwargs):
        """Asynchronously send messages to ``/chat_completion``.

        A fresh ``aiohttp.ClientSession`` is opened for each call and closed
        before the call returns.

        Args:
            *message: Messages to send; ``str`` items are sent as-is, any
                other item is serialized with ``model_dump()``.
            session_id (int): Session to route the request to.
            **kwargs: Extra fields merged into the JSON payload.

        Returns:
            AgentMessage | Any: The validated reply on HTTP 200, otherwise the
            decoded JSON error body unchanged.
        """
        payload = {
            'message': [
                m if isinstance(m, str) else m.model_dump() for m in message
            ],
            'session_id': session_id,
            **kwargs,
        }
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
            async with session.post(
                    f'http://{self.host}:{self.port}/chat_completion',
                    json=payload,
                    headers={'Content-Type': 'application/json'},
            ) as response:
                # content_type=None disables aiohttp's Content-Type check so
                # the body is decoded the same way requests' Response.json()
                # does in the synchronous client.
                resp = await response.json(content_type=None)
                if response.status != 200:
                    return resp
                return AgentMessage.model_validate(resp)


class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient):
    """HTTP agent client whose ``__call__`` is a coroutine.

    ``AsyncHTTPAgentMixin`` precedes ``HTTPAgentClient`` in the bases, so its
    async ``__call__`` wins over the blocking one; ``is_alive`` and
    ``state_dict`` are inherited unchanged from ``HTTPAgentClient``.
    """


class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer):
    """Subprocess-backed agent server whose ``__call__`` is a coroutine.

    Process management comes from ``HTTPAgentServer``. The async
    ``__call__`` from ``AsyncHTTPAgentMixin`` overrides the blocking one
    because the mixin is listed first in the bases.
    """