File size: 6,671 Bytes
f3f614f
 
 
dc9e27a
f3f614f
 
 
 
 
dc9e27a
 
f3f614f
 
 
 
 
 
 
dc9e27a
 
 
 
 
 
 
 
f3f614f
 
 
 
dc9e27a
 
 
 
 
 
 
 
f3f614f
 
 
 
dc9e27a
f3f614f
 
 
dc9e27a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3f614f
 
 
dc9e27a
f3f614f
dc9e27a
f3f614f
 
dc9e27a
f3f614f
 
dc9e27a
f3f614f
dc9e27a
f3f614f
 
 
 
 
 
 
dc9e27a
f3f614f
 
dc9e27a
 
 
 
 
 
 
 
 
 
f3f614f
dc9e27a
f3f614f
 
dc9e27a
 
 
f3f614f
dc9e27a
f3f614f
 
dc9e27a
f3f614f
 
dc9e27a
 
 
 
 
 
 
 
f3f614f
dc9e27a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3f614f
dc9e27a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import asyncio
import json
import logging
import random
from typing import Dict, List, Union

import janus
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from mindsearch.agent import init_agent


def parse_arguments():
    """Build and evaluate the command-line interface of the MindSearch API.

    Returns:
        argparse.Namespace: parsed options — ``host``, ``port``, ``lang``,
        ``model_format``, ``search_engine`` and the ``asy`` flag.
    """
    import argparse

    parser = argparse.ArgumentParser(description="MindSearch API")
    add = parser.add_argument
    add("--host", type=str, default="0.0.0.0", help="Service host")
    add("--port", type=int, default=8002, help="Service port")
    add("--lang", type=str, default="cn", help="Language")
    add("--model_format", type=str, default="internlm_server", help="Model format")
    add("--search_engine", type=str, default="BingSearch", help="Search engine")
    add("--asy", action="store_true", default=False, help="Agent mode")
    return parser.parse_args()


# Parse CLI options once at import time; the handlers below read them from here.
args = parse_arguments()
app = FastAPI(docs_url="/")  # serve the interactive API docs at the root path
# Permit browser clients from any origin to reach the SSE endpoint.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class GenerationParams(BaseModel):
    """Request body accepted by the /solve endpoint."""

    # Either a plain query string or a list of chat-style message dicts.
    inputs: Union[str, List[Dict]]
    # Identifies the agent session; a random id is drawn when the client omits it.
    session_id: int = Field(default_factory=lambda: random.randint(0, 999999))
    # Extra agent configuration; not read by the handlers visible in this file.
    agent_cfg: Dict = dict()


def _postprocess_agent_message(message: dict) -> dict:
    content, fmt = message["content"], message["formatted"]
    current_node = content["current_node"] if isinstance(content, dict) else None
    if current_node:
        message["content"] = None
        for key in ["ref2url"]:
            fmt.pop(key, None)
        graph = fmt["node"]
        for key in graph.copy():
            if key != current_node:
                graph.pop(key)
        if current_node not in ["root", "response"]:
            node = graph[current_node]
            for key in ["memory", "session_id"]:
                node.pop(key, None)
            node_fmt = node["response"]["formatted"]
            if isinstance(node_fmt, dict) and "thought" in node_fmt and "action" in node_fmt:
                node["response"]["content"] = None
                node_fmt["thought"] = (
                    node_fmt["thought"] and node_fmt["thought"].split("<|action_start|>")[0]
                )
                if isinstance(node_fmt["action"], str):
                    node_fmt["action"] = node_fmt["action"].split("<|action_end|>")[0]
    else:
        if isinstance(fmt, dict) and "thought" in fmt and "action" in fmt:
            message["content"] = None
            fmt["thought"] = fmt["thought"] and fmt["thought"].split("<|action_start|>")[0]
            if isinstance(fmt["action"], str):
                fmt["action"] = fmt["action"].split("<|action_end|>")[0]
        for key in ["node"]:
            fmt.pop(key, None)
    return dict(current_node=current_node, response=message)


async def run(request: GenerationParams, _request: Request):
    """Serve one /solve request with a synchronous agent, streamed as SSE.

    The blocking agent generator runs in a thread-pool executor; its outputs
    are relayed to the event loop through a ``janus`` queue, which exposes a
    sync endpoint (producer thread) and an async endpoint (consumer coroutine)
    over the same buffer. A ``None`` sentinel marks end of stream.
    """

    async def generate():
        try:
            queue = janus.Queue()
            stop_event = asyncio.Event()

            # Wrapping a sync generator as an async generator using run_in_executor
            def sync_generator_wrapper():
                try:
                    for response in agent(inputs, session_id=session_id):
                        queue.sync_q.put(response)
                except Exception as e:
                    logging.exception(f"Exception in sync_generator_wrapper: {e}")
                finally:
                    # Notify async_generator_wrapper that the data generation is complete.
                    queue.sync_q.put(None)

            async def async_generator_wrapper():
                loop = asyncio.get_event_loop()
                # NOTE(review): the executor future is discarded, so a failure in
                # the worker is only visible via the logging call above.
                loop.run_in_executor(None, sync_generator_wrapper)
                while True:
                    response = await queue.async_q.get()
                    if response is None:  # Ensure that all elements are consumed
                        break
                    yield response
                stop_event.set()  # Inform sync_generator_wrapper to stop
                # NOTE(review): sync_generator_wrapper never reads stop_event; it
                # only stops when the agent generator is exhausted — confirm intent.

            async for message in async_generator_wrapper():
                response_json = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": response_json}
                # Stop streaming as soon as the client goes away.
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            # Surface the failure to the client as a final SSE error event.
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": response_json}
        finally:
            # NOTE(review): if the client disconnects and the async-for breaks
            # early, stop_event is never set and this wait can hang — confirm.
            await stop_event.wait()  # Waiting for async_generator_wrapper to stop
            queue.close()
            await queue.wait_closed()
            # Drop any per-session memory the agent accumulated.
            agent.agent.memory.memory_map.pop(session_id, None)

    inputs = request.inputs
    session_id = request.session_id
    # A fresh synchronous agent per request; settings come from the CLI options.
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
    )
    return EventSourceResponse(generate(), ping=300)


async def run_async(request: GenerationParams, _request: Request):
    """Serve one /solve request with a natively asynchronous agent over SSE."""
    inputs = request.inputs
    session_id = request.session_id
    # A fresh async agent per request; settings come from the CLI options.
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
        use_async=True,
    )

    async def generate():
        try:
            async for message in agent(inputs, session_id=session_id):
                payload = _postprocess_agent_message(message.model_dump())
                yield {"data": json.dumps(payload, ensure_ascii=False)}
                # Stop streaming as soon as the client goes away.
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            # Surface the failure to the client as a final SSE error event.
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            error_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": error_json}
        finally:
            # Drop any per-session memory the agent accumulated.
            agent.agent.memory.memory_map.pop(session_id, None)

    return EventSourceResponse(generate(), ping=300)


# Route /solve to the fully-async handler when --asy was given, else the
# thread-bridged synchronous one.
app.add_api_route("/solve", run_async if args.asy else run, methods=["POST"])

if __name__ == "__main__":
    import uvicorn

    # Run the service directly with uvicorn using the parsed CLI options.
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")