import sys
import os

# Make the repository root importable when this script is run directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
import asyncio
from argparse import Namespace
from typing import List, Set

from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared

from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor


class CustomLLMSingleActionAgent(ZeroShotAgent):
    # ZeroShotAgent is a pydantic model, so this required field is populated
    # directly from the constructor kwargs; no custom __init__ is needed.
    allowed_tools: List[str]

    def get_allowed_tools(self) -> Set[str]:
        # AgentExecutor uses this to validate that the agent only dispatches
        # to tools it was actually given.
        return set(self.allowed_tools)


async def dispatch(args: Namespace):
    args_dict = vars(args)

    # Load the checkpoint described by the CLI arguments and wrap it in a
    # LangChain-compatible LLM instance.
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()

    template = """This is a conversation between a human and a bot:

{chat_history}

Write a summary of the conversation for {input}:
"""

    prompt = PromptTemplate(
        input_variables=["input", "chat_history"],
        template=template
    )
    memory = ConversationBufferMemory(memory_key="chat_history")
    readonlymemory = ReadOnlySharedMemory(memory=memory)
    summary_chain = LLMChain(
        llm=llm_model_ins,
        prompt=prompt,
        verbose=True,
        memory=readonlymemory,  # read-only memory keeps the tool from modifying the shared history
    )

    # Expose the summarization chain to the agent as a single callable tool.
    tools = [
        Tool(
            name="Summary",
            func=summary_chain.run,
            description="useful for when you need to summarize the conversation. "
                        "The input to this tool should be a string naming who will read the summary."
        )
    ]

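    # Prompt scaffolding for the zero-shot agent; create_prompt() inserts the
    # tool names and descriptions between the prefix and the suffix.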
    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    suffix = """Begin!

Question: {input}
{agent_scratchpad}"""

    prompt = CustomLLMSingleActionAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "agent_scratchpad"]
    )
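    # Build the LLM chain from the assembled prompt, wrap it in the custom
    # agent, and hand everything to an executor that runs the tool-use loop.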
    tool_names = [tool.name for tool in tools]
    llm_chain = LLMChain(llm=llm_model_ins, prompt=prompt)
    agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names)
    agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools)

    agent_chain.run(input="Hello")
    agent_chain.run(input="Who are you?")
    agent_chain.run(input="What did we talk about earlier?")

if __name__ == '__main__':
    # Example invocation; point --model-dir and --model at a local checkpoint.
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit'])

    asyncio.run(dispatch(args))