Spaces:
Sleeping
Sleeping
File size: 6,671 Bytes
f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a f3f614f dc9e27a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import asyncio
import json
import logging
import random
from typing import Dict, List, Union
import janus
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from mindsearch.agent import init_agent
def parse_arguments():
    """Parse command-line options for the MindSearch API service."""
    import argparse

    parser = argparse.ArgumentParser(description="MindSearch API")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Service host")
    parser.add_argument("--port", type=int, default=8002, help="Service port")
    parser.add_argument("--lang", type=str, default="cn", help="Language")
    parser.add_argument(
        "--model_format",
        type=str,
        default="internlm_server",
        help="Model format",
    )
    parser.add_argument(
        "--search_engine",
        type=str,
        default="BingSearch",
        help="Search engine",
    )
    parser.add_argument("--asy", default=False, action="store_true", help="Agent mode")
    return parser.parse_args()
# Parse CLI options once at import time; route handlers and __main__ read `args`.
args = parse_arguments()
# Serve the interactive API docs at the root path.
app = FastAPI(docs_url="/")
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # the most permissive CORS setup — confirm this is intended for deployment.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class GenerationParams(BaseModel):
    """Request body for the /solve endpoint.

    Attributes:
        inputs: User query — a raw string or a list of chat-message dicts.
        session_id: Key for per-session agent memory; a random value is
            generated when the client omits it.
        agent_cfg: Optional agent-construction overrides.
    """

    inputs: Union[str, List[Dict]]
    session_id: int = Field(default_factory=lambda: random.randint(0, 999999))
    # default_factory gives each request its own dict — idiomatic pydantic
    # form, consistent with the session_id field above.
    agent_cfg: Dict = Field(default_factory=dict)
def _postprocess_agent_message(message: dict) -> dict:
content, fmt = message["content"], message["formatted"]
current_node = content["current_node"] if isinstance(content, dict) else None
if current_node:
message["content"] = None
for key in ["ref2url"]:
fmt.pop(key, None)
graph = fmt["node"]
for key in graph.copy():
if key != current_node:
graph.pop(key)
if current_node not in ["root", "response"]:
node = graph[current_node]
for key in ["memory", "session_id"]:
node.pop(key, None)
node_fmt = node["response"]["formatted"]
if isinstance(node_fmt, dict) and "thought" in node_fmt and "action" in node_fmt:
node["response"]["content"] = None
node_fmt["thought"] = (
node_fmt["thought"] and node_fmt["thought"].split("<|action_start|>")[0]
)
if isinstance(node_fmt["action"], str):
node_fmt["action"] = node_fmt["action"].split("<|action_end|>")[0]
else:
if isinstance(fmt, dict) and "thought" in fmt and "action" in fmt:
message["content"] = None
fmt["thought"] = fmt["thought"] and fmt["thought"].split("<|action_start|>")[0]
if isinstance(fmt["action"], str):
fmt["action"] = fmt["action"].split("<|action_end|>")[0]
for key in ["node"]:
fmt.pop(key, None)
return dict(current_node=current_node, response=message)
async def run(request: GenerationParams, _request: Request):
    """Stream agent responses for *request* as server-sent events.

    Used when the service runs a synchronous agent: the agent's blocking
    generator is driven on the default thread-pool executor and bridged to
    this async handler through a janus (sync <-> async) queue.
    """
    async def generate():
        try:
            # janus.Queue exposes sync_q for the executor thread and async_q
            # for this coroutine.
            queue = janus.Queue()
            # Set once streaming finishes so the finally block can clean up.
            stop_event = asyncio.Event()
            # Wrapping a sync generator as an async generator using run_in_executor
            def sync_generator_wrapper():
                try:
                    for response in agent(inputs, session_id=session_id):
                        queue.sync_q.put(response)
                except Exception as e:
                    logging.exception(f"Exception in sync_generator_wrapper: {e}")
                finally:
                    # Notify async_generator_wrapper that the data generation is complete.
                    queue.sync_q.put(None)
            async def async_generator_wrapper():
                loop = asyncio.get_event_loop()
                # Fire-and-forget: the returned future is not awaited; errors
                # are logged inside sync_generator_wrapper itself.
                loop.run_in_executor(None, sync_generator_wrapper)
                while True:
                    response = await queue.async_q.get()
                    if response is None:  # Ensure that all elements are consumed
                        break
                    yield response
                # NOTE(review): stop_event is never read by
                # sync_generator_wrapper despite the comment below; it only
                # unblocks the outer finally. It is set only when the None
                # sentinel is reached, so if the client disconnects or an
                # exception breaks the async-for early, `await
                # stop_event.wait()` in the finally below may block
                # indefinitely — confirm, and consider setting it in a
                # finally clause here instead.
                stop_event.set()  # Inform sync_generator_wrapper to stop
            async for message in async_generator_wrapper():
                response_json = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": response_json}
                # Stop streaming once the client has gone away.
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": response_json}
        finally:
            await stop_event.wait()  # Waiting for async_generator_wrapper to stop
            queue.close()
            await queue.wait_closed()
            # Drop this session's memory so the agent does not leak state.
            agent.agent.memory.memory_map.pop(session_id, None)
    # Closure state read by generate(); bound after its definition.
    inputs = request.inputs
    session_id = request.session_id
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
    )
    return EventSourceResponse(generate(), ping=300)
async def run_async(request: GenerationParams, _request: Request):
    """Stream agent responses for *request* as server-sent events.

    Used when the service runs a fully asynchronous agent (`--asy`), so
    messages are consumed directly with `async for` — no executor bridge.
    """
    inputs = request.inputs
    session_id = request.session_id
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
        use_async=True,
    )

    async def generate():
        try:
            async for message in agent(inputs, session_id=session_id):
                payload = _postprocess_agent_message(message.model_dump())
                yield {"data": json.dumps(payload, ensure_ascii=False)}
                # Stop streaming once the client has gone away.
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            yield {
                "data": json.dumps(
                    dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
                )
            }
        finally:
            # Drop this session's memory so the agent does not leak state.
            agent.agent.memory.memory_map.pop(session_id, None)

    return EventSourceResponse(generate(), ping=300)
# Single /solve endpoint; the async-agent handler is used when the service
# was started with --asy, otherwise the executor-bridged sync handler.
app.add_api_route("/solve", run_async if args.asy else run, methods=["POST"])
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|