microstronger committed
Commit: 2449c50
1 parent: 9a3998f

Upload 3 files

Files changed (3)
  1. app.py +136 -0
  2. requirements.txt +12 -0
  3. terminal.py +50 -0
app.py ADDED
@@ -0,0 +1,136 @@
import asyncio
import json
import logging
from copy import deepcopy
from dataclasses import asdict
from typing import Dict, List, Union

import janus
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from lagent.schema import AgentStatusCode
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from mindsearch.agent import init_agent


def parse_arguments():
    import argparse
    parser = argparse.ArgumentParser(description='MindSearch API')
    parser.add_argument('--lang', default='cn', type=str, help='Language')
    parser.add_argument('--model_format',
                        default='internlm_server',
                        type=str,
                        help='Model format')
    parser.add_argument('--search_engine',
                        default='DuckDuckGoSearch',
                        type=str,
                        help='Search engine')
    return parser.parse_args()


args = parse_arguments()
app = FastAPI(docs_url='/')

app.add_middleware(CORSMiddleware,
                   allow_origins=['*'],
                   allow_credentials=True,
                   allow_methods=['*'],
                   allow_headers=['*'])


class GenerationParams(BaseModel):
    inputs: Union[str, List[Dict]]
    agent_cfg: Dict = dict()


@app.post('/solve')
async def run(request: GenerationParams):

    def convert_adjacency_to_tree(adjacency_input, root_name):

        def build_tree(node_name):
            node = {'name': node_name, 'children': []}
            if node_name in adjacency_input:
                for child in adjacency_input[node_name]:
                    child_node = build_tree(child['name'])
                    child_node['state'] = child['state']
                    child_node['id'] = child['id']
                    node['children'].append(child_node)
            return node

        return build_tree(root_name)

    async def generate():
        try:
            queue = janus.Queue()
            stop_event = asyncio.Event()

            # Wrap the sync generator as an async generator using run_in_executor.
            def sync_generator_wrapper():
                try:
                    for response in agent.stream_chat(inputs):
                        queue.sync_q.put(response)
                except Exception as e:
                    logging.exception(
                        f'Exception in sync_generator_wrapper: {e}')
                finally:
                    # Notify async_generator_wrapper that the data generation is complete.
                    queue.sync_q.put(None)

            async def async_generator_wrapper():
                loop = asyncio.get_event_loop()
                loop.run_in_executor(None, sync_generator_wrapper)
                while True:
                    response = await queue.async_q.get()
                    if response is None:  # Ensure that all elements are consumed
                        break
                    yield response
                    if not isinstance(
                            response,
                            tuple) and response.state == AgentStatusCode.END:
                        break
                stop_event.set()  # Signal the finally block that streaming has finished

            async for response in async_generator_wrapper():
                if isinstance(response, tuple):
                    agent_return, node_name = response
                else:
                    agent_return = response
                    node_name = None
                origin_adj = deepcopy(agent_return.adjacency_list)
                adjacency_list = convert_adjacency_to_tree(
                    agent_return.adjacency_list, 'root')
                assert adjacency_list[
                    'name'] == 'root' and 'children' in adjacency_list
                agent_return.adjacency_list = adjacency_list['children']
                agent_return = asdict(agent_return)
                agent_return['adj'] = origin_adj
                response_json = json.dumps(dict(response=agent_return,
                                                current_node=node_name),
                                           ensure_ascii=False)
                yield {'data': response_json}
        except Exception as exc:
            msg = 'An error occurred while generating the response.'
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))),
                ensure_ascii=False)
            yield {'data': response_json}
        finally:
            await stop_event.wait()  # Wait for async_generator_wrapper to finish
            queue.close()
            await queue.wait_closed()

    inputs = request.inputs
    agent = init_agent(lang=args.lang,
                       model_format=args.model_format,
                       search_engine=args.search_engine)
    return EventSourceResponse(generate())


if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8002, log_level='info')
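
For reference, a minimal way to exercise this API is to start the server (python app.py --lang cn --model_format internlm_server) and stream the SSE events from the /solve endpoint. The client sketch below is an illustration, not part of the commit: it assumes the requests library is available (it is not listed in requirements.txt); the URL, port, and payload shape follow the server code above.

import json
import requests

url = 'http://localhost:8002/solve'
payload = {'inputs': '上海今天适合穿什么衣服'}  # "What should I wear in Shanghai today?"

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # EventSourceResponse emits lines of the form "data: <json>"; skip blanks and other fields.
        if not line or not line.startswith('data: '):
            continue
        chunk = json.loads(line[len('data: '):])
        # Each chunk carries the serialized agent state plus the node currently being processed.
        print(chunk.get('current_node'))

Each streamed chunk mirrors the dict built in generate(): a "response" entry with the agent return (including the tree-shaped adjacency list) and a "current_node" entry, or an "error" entry if generation failed.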
requirements.txt ADDED
@@ -0,0 +1,12 @@
duckduckgo_search==5.3.1b1
einops
fastapi
git+https://github.com/InternLM/lagent.git
gradio
janus
lmdeploy
pyvis
sse-starlette
termcolor
transformers==4.41.0
uvicorn
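
These dependencies, including lagent installed directly from its Git repository, can be set up with the standard pip command below; a Python 3 environment is assumed.

pip install -r requirements.txt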
terminal.py ADDED
@@ -0,0 +1,50 @@
from datetime import datetime

from lagent.actions import ActionExecutor, BingBrowser
from lagent.llms import INTERNLM2_META, LMDeployServer

from mindsearch.agent.mindsearch_agent import (MindSearchAgent,
                                               MindSearchProtocol)
from mindsearch.agent.mindsearch_prompt import (
    FINAL_RESPONSE_CN, FINAL_RESPONSE_EN, GRAPH_PROMPT_CN, GRAPH_PROMPT_EN,
    searcher_context_template_cn, searcher_context_template_en,
    searcher_input_template_cn, searcher_input_template_en,
    searcher_system_prompt_cn, searcher_system_prompt_en)

lang = 'cn'
llm = LMDeployServer(path='internlm/internlm2_5-7b-chat',
                     model_name='internlm2',
                     meta_template=INTERNLM2_META,
                     top_p=0.8,
                     top_k=1,
                     temperature=0,
                     max_new_tokens=8192,
                     repetition_penalty=1.02,
                     stop_words=['<|im_end|>'])

agent = MindSearchAgent(
    llm=llm,
    protocol=MindSearchProtocol(
        meta_prompt=datetime.now().strftime('The current date is %Y-%m-%d.'),
        interpreter_prompt=GRAPH_PROMPT_CN
        if lang == 'cn' else GRAPH_PROMPT_EN,
        response_prompt=FINAL_RESPONSE_CN
        if lang == 'cn' else FINAL_RESPONSE_EN),
    searcher_cfg=dict(
        llm=llm,
        plugin_executor=ActionExecutor(
            BingBrowser(searcher_type='DuckDuckGoSearch', topk=6)),
        protocol=MindSearchProtocol(
            meta_prompt=datetime.now().strftime(
                'The current date is %Y-%m-%d.'),
            plugin_prompt=searcher_system_prompt_cn
            if lang == 'cn' else searcher_system_prompt_en,
        ),
        template=dict(input=searcher_input_template_cn
                      if lang == 'cn' else searcher_input_template_en,
                      context=searcher_context_template_cn
                      if lang == 'cn' else searcher_context_template_en)),
    max_turn=10)

# Example query: "What should I wear in Shanghai today?"
for agent_return in agent.stream_chat('上海今天适合穿什么衣服'):
    pass
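
The loop above drains the stream without printing anything. The sketch below is one way to surface the final answer; it assumes, as app.py does, that stream_chat may yield either a plain agent return or an (agent_return, node_name) tuple, and that the return object exposes a response field.

final = None
for item in agent.stream_chat('上海今天适合穿什么衣服'):  # "What should I wear in Shanghai today?"
    # Mirror app.py: items may be (agent_return, node_name) tuples or plain returns.
    final = item[0] if isinstance(item, tuple) else item
print(getattr(final, 'response', final))  # assumption: the agent return exposes a 'response' field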