File size: 4,277 Bytes
98e07ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import ssl

import gradio as gr
import nltk
from config_utils import parse_configuration
from custom_prompt import CustomPromptGenerator
from custom_prompt_zh import ZhCustomPromptGenerator
from langchain.embeddings import ModelScopeEmbeddings
from langchain.vectorstores import FAISS
from modelscope_agent import prompt_generator_register
from modelscope_agent.agent import AgentExecutor
from modelscope_agent.agent_types import AgentType
from modelscope_agent.llm import LLMFactory
from modelscope_agent.retrieve import KnowledgeRetrieval
from modelscope_agent.tools.openapi_plugin import OpenAPIPluginTool
from modelscope_agent.utils.logger import agent_logger as logger

# Register the project's custom prompt generators under the string names
# that init_user_chatbot_agent later passes as 'prompt_generator'.
prompts = dict(
    CustomPromptGenerator=CustomPromptGenerator,
    ZhCustomPromptGenerator=ZhCustomPromptGenerator,
)
prompt_generator_register(prompts)

# try:
#     _create_unverified_https_context = ssl._create_unverified_context
# except AttributeError:
#     pass
# else:
#     ssl._create_default_https_context = _create_unverified_https_context
#
# nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger')


# Initialize the user-facing chatbot agent.
def init_user_chatbot_agent(uuid_str=''):
    """Build the user-facing chatbot agent from the stored configuration.

    Loads the builder/model/tool/plugin configuration through
    ``parse_configuration``, builds the LLM, picks a prompt generator,
    optionally builds a knowledge retriever from the configured knowledge
    files, and wires everything into an ``AgentExecutor``.

    Args:
        uuid_str: Identifier forwarded to ``parse_configuration``, to the
            logger and to the prompt configuration. Defaults to ``''``.

    Returns:
        The configured ``AgentExecutor`` with the available tools and
        plugins enabled.

    Raises:
        gr.Error: If ``LLMFactory.build_llm`` fails. The original exception
            is chained as ``__cause__`` so the traceback is not lost.
    """
    builder_cfg, model_cfg, tool_cfg, available_tool_list, plugin_cfg, available_plugin_list = parse_configuration(
        uuid_str)
    # set top_p and stop_words for role play
    model_cfg[builder_cfg.model]['generate_cfg']['top_p'] = 0.5
    model_cfg[builder_cfg.model]['generate_cfg']['stop'] = 'Observation'

    # build model
    logger.info(
        uuid=uuid_str,
        message=f'using model {builder_cfg.model}',
        content={'model_config': model_cfg[builder_cfg.model]})

    # # check configuration
    # if builder_cfg.model in ['qwen-max', 'qwen-72b-api', 'qwen-14b-api', 'qwen-plus']:
    #     if 'DASHSCOPE_API_KEY' not in os.environ:
    #         raise gr.Error('DASHSCOPE_API_KEY should be set via setting environment variable')

    try:
        llm = LLMFactory.build_llm(builder_cfg.model, model_cfg)
    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the root cause.
        raise gr.Error(str(e)) from e

    # build prompt with zero shot react template
    prompt_generator = builder_cfg.get('prompt_generator', None)
    if builder_cfg.model.startswith('qwen') and not prompt_generator:
        # No explicit generator configured for a qwen model: fall back to the
        # custom generator that matches the configured language.
        prompt_generator = 'CustomPromptGenerator'
        language = builder_cfg.get('language', 'en')
        if language == 'zh':
            prompt_generator = 'ZhCustomPromptGenerator'

    # Only the first configured knowledge entry contributes its file name.
    first_knowledge = (
        builder_cfg.knowledge[0] if len(builder_cfg.knowledge) > 0 else '')
    prompt_cfg = {
        'prompt_generator': prompt_generator,
        'add_addition_round': True,
        'knowledge_file_name': os.path.basename(first_knowledge),
        'uuid_str': uuid_str,
    }

    # get knowledge
    # Only existing files ending in .txt, .md or .pdf are supported.
    available_knowledge_list = [
        item for item in builder_cfg.knowledge
        if os.path.isfile(item) and item.endswith(('.txt', '.md', '.pdf'))
    ]
    knowledge_retrieval = None
    if available_knowledge_list:
        # Vector store configuration for the open-source version. The
        # embedding model is built only when there is something to index, so
        # agents without knowledge files skip it entirely.
        model_id = 'damo/nlp_gte_sentence-embedding_chinese-base'
        embeddings = ModelScopeEmbeddings(model_id=model_id)
        knowledge_retrieval = KnowledgeRetrieval.from_file(
            available_knowledge_list, embeddings, FAISS)

    additional_tool_list = add_openapi_plugin_to_additional_tool(
        plugin_cfg, available_plugin_list)
    # build agent
    agent = AgentExecutor(
        llm,
        additional_tool_list=additional_tool_list,
        tool_cfg=tool_cfg,
        agent_type=AgentType.MRKL,
        knowledge_retrieval=knowledge_retrieval,
        tool_retrieval=False,
        **prompt_cfg)
    agent.set_available_tools(available_tool_list + available_plugin_list)
    return agent


def add_openapi_plugin_to_additional_tool(plugin_cfgs, available_plugin_list):
    """Create one ``OpenAPIPluginTool`` per entry in ``plugin_cfgs``.

    Args:
        plugin_cfgs: Mapping of plugin name to its configuration. The whole
            mapping (not the per-plugin entry) is handed to every tool as
            ``cfg``, together with the plugin's name.
        available_plugin_list: Accepted for interface compatibility; it is
            not read by this function.

    Returns:
        A dict mapping each plugin name to its ``OpenAPIPluginTool``.
    """
    return {
        plugin_name: OpenAPIPluginTool(name=plugin_name, cfg=plugin_cfgs)
        for plugin_name, _ in plugin_cfgs.items()
    }


def user_chatbot_single_run(query, agent):
    """Run ``agent`` once on ``query``.

    Whatever ``agent.run`` returns is discarded; this function itself
    returns ``None``.
    """
    run_once = agent.run
    run_once(query)