# Standard library
import asyncio
import json
import logging
import random
from typing import Dict, List, Union

# Third-party
import janus
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

# Local
from mindsearch.agent import init_agent
def parse_arguments():
    """Parse the service's command-line options.

    Returns:
        argparse.Namespace with the fields:
            host (str): bind address, default "0.0.0.0".
            port (int): bind port, default 8002.
            lang (str): agent language, default "cn".
            model_format (str): backend model format, default "internlm_server".
            search_engine (str): search backend name, default "BingSearch".
            asy (bool): True selects the fully-async agent route.
    """
    import argparse

    parser = argparse.ArgumentParser(description="MindSearch API")
    parser.add_argument("--host", default="0.0.0.0", type=str, help="Service host")
    parser.add_argument("--port", default=8002, type=int, help="Service port")
    parser.add_argument("--lang", default="cn", type=str, help="Language")
    parser.add_argument("--model_format", default="internlm_server", type=str, help="Model format")
    parser.add_argument("--search_engine", default="BingSearch", type=str, help="Search engine")
    # store_true already implies default=False; the explicit default was redundant.
    parser.add_argument("--asy", action="store_true", help="Agent mode")
    return parser.parse_args()
# CLI options are parsed once at import time; `args` is read by the route
# handlers and the __main__ block below.
args = parse_arguments()

# Swagger UI is served at the root path.
app = FastAPI(docs_url="/")
# Fully open CORS policy: any origin, method, and header, credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class GenerationParams(BaseModel):
    """Request body for the /solve endpoint."""

    # Either a bare query string or a chat history as role/content dicts.
    inputs: Union[str, List[Dict]]
    # Keys the agent's per-session memory; a random id is drawn when the
    # client does not supply one.
    session_id: int = Field(default_factory=lambda: random.randint(0, 999999))
    # Per-request agent configuration; not read by the handlers in this file.
    # Use default_factory instead of a shared literal default.
    agent_cfg: Dict = Field(default_factory=dict)
def _postprocess_agent_message(message: dict) -> dict: | |
content, fmt = message["content"], message["formatted"] | |
current_node = content["current_node"] if isinstance(content, dict) else None | |
if current_node: | |
message["content"] = None | |
for key in ["ref2url"]: | |
fmt.pop(key, None) | |
graph = fmt["node"] | |
for key in graph.copy(): | |
if key != current_node: | |
graph.pop(key) | |
if current_node not in ["root", "response"]: | |
node = graph[current_node] | |
for key in ["memory", "session_id"]: | |
node.pop(key, None) | |
node_fmt = node["response"]["formatted"] | |
if isinstance(node_fmt, dict) and "thought" in node_fmt and "action" in node_fmt: | |
node["response"]["content"] = None | |
node_fmt["thought"] = ( | |
node_fmt["thought"] and node_fmt["thought"].split("<|action_start|>")[0] | |
) | |
if isinstance(node_fmt["action"], str): | |
node_fmt["action"] = node_fmt["action"].split("<|action_end|>")[0] | |
else: | |
if isinstance(fmt, dict) and "thought" in fmt and "action" in fmt: | |
message["content"] = None | |
fmt["thought"] = fmt["thought"] and fmt["thought"].split("<|action_start|>")[0] | |
if isinstance(fmt["action"], str): | |
fmt["action"] = fmt["action"].split("<|action_end|>")[0] | |
for key in ["node"]: | |
fmt.pop(key, None) | |
return dict(current_node=current_node, response=message) | |
async def run(request: GenerationParams, _request: Request):
    """Serve one /solve request with the synchronous agent as an SSE stream.

    The blocking agent generator runs on a default-executor thread; a janus
    queue bridges its items into the async SSE generator, with ``None`` as
    the end-of-stream sentinel.
    """

    async def generate():
        try:
            queue = janus.Queue()
            stop_event = asyncio.Event()

            # Wrapping a sync generator as an async generator using run_in_executor
            def sync_generator_wrapper():
                try:
                    # Blocking iteration; runs entirely on the executor thread.
                    for response in agent(inputs, session_id=session_id):
                        queue.sync_q.put(response)
                except Exception as e:
                    logging.exception(f"Exception in sync_generator_wrapper: {e}")
                finally:
                    # Notify async_generator_wrapper that the data generation is complete.
                    queue.sync_q.put(None)

            async def async_generator_wrapper():
                # NOTE(review): get_event_loop() is deprecated inside coroutines
                # in favor of get_running_loop() — confirm target Python version.
                loop = asyncio.get_event_loop()
                # Fire-and-forget: the executor future is deliberately not
                # awaited; the None sentinel marks producer completion instead.
                loop.run_in_executor(None, sync_generator_wrapper)
                while True:
                    response = await queue.async_q.get()
                    if response is None:  # Ensure that all elements are consumed
                        break
                    yield response
                # Signals the finally block below that draining finished.
                # NOTE(review): sync_generator_wrapper never checks this event,
                # and if the consumer breaks out of the async-for early (client
                # disconnect) this line is never reached, so the finally
                # block's stop_event.wait() can hang — confirm.
                stop_event.set()  # Inform sync_generator_wrapper to stop

            async for message in async_generator_wrapper():
                response_json = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": response_json}
                # Disconnect is checked after the yield, so one final event is
                # still emitted to a gone client before the loop exits.
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": response_json}
        finally:
            await stop_event.wait()  # Waiting for async_generator_wrapper to stop
            queue.close()
            await queue.wait_closed()
            # Drop this session's memory so the id can be reused.
            agent.agent.memory.memory_map.pop(session_id, None)

    inputs = request.inputs
    session_id = request.session_id
    # A fresh agent is built per request from the process-wide CLI options.
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
    )
    # ping=300: SSE keep-alive interval — presumably seconds; verify against
    # sse_starlette's EventSourceResponse documentation.
    return EventSourceResponse(generate(), ping=300)
async def run_async(request: GenerationParams, _request: Request):
    """Serve one /solve request with the fully-async agent as an SSE stream."""
    inputs = request.inputs
    session_id = request.session_id
    # A fresh async agent is built per request from the process-wide CLI options.
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
        use_async=True,
    )

    async def generate():
        try:
            async for message in agent(inputs, session_id=session_id):
                payload = _postprocess_agent_message(message.model_dump())
                yield {"data": json.dumps(payload, ensure_ascii=False)}
                # Stop streaming once the client has gone away (checked after
                # the yield, so one last event still goes out).
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            error_payload = dict(error=dict(msg=msg, details=str(exc)))
            yield {"data": json.dumps(error_payload, ensure_ascii=False)}
        finally:
            # Drop this session's memory so the id can be reused.
            agent.agent.memory.memory_map.pop(session_id, None)

    return EventSourceResponse(generate(), ping=300)
# Register the solver endpoint; --asy selects the fully-async agent handler,
# otherwise the threaded sync-agent handler is used.
app.add_api_route("/solve", run_async if args.asy else run, methods=["POST"])

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")