# Source: CongMa/test/models/test_vicuna_chain_agent.py
# Uploaded by XuBailing in commit 107f987 ("Upload 243 files"); original size 2.97 kB.
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../')
import asyncio
from argparse import Namespace
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from typing import List, Set
class CustomLLMSingleActionAgent(ZeroShotAgent):
    """ZeroShotAgent that reports an explicit list of permitted tool names.

    ``allowed_tools`` is declared as a required model field. LangChain agents
    are pydantic models, so the inherited initializer fills this field from
    the ``allowed_tools=`` keyword argument. A custom ``__init__`` that
    re-assigns it is therefore unnecessary.
    """

    # Names of the tools this agent may invoke (passed as allowed_tools=...).
    allowed_tools: List[str]

    def get_allowed_tools(self) -> Set[str]:
        """Return the permitted tool names as a set."""
        return set(self.allowed_tools)
async def dispatch(args: Namespace) -> None:
    """Load the configured LLM and run a short conversational-agent smoke test.

    Builds an agent with a single ``Summary`` tool that summarises the
    conversation so far. The agent executor owns a
    ``ConversationBufferMemory`` and records every turn into it. The tool
    reads the same history through a ``ReadOnlySharedMemory`` wrapper, so
    running the tool cannot modify it. Three hard-coded questions are then
    asked in sequence, the last one about the earlier conversation.

    Args:
        args: Parsed command-line options; converted with ``vars`` and passed
            to ``LoaderCheckPoint``.
    """
    # Load the checkpoint into the shared module state, then build the LLM.
    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()

    # Prompt for the Summary tool: the recorded history, plus the audience
    # the summary is written for (the tool's input string).
    template = """This is a conversation between a human and a bot:
{chat_history}
Write a summary of the conversation for {input}:
"""
    summary_prompt = PromptTemplate(
        input_variables=["input", "chat_history"],
        template=template
    )

    # Single history buffer. The executor writes to it (memory=memory below);
    # the tool only gets a read-only view.
    memory = ConversationBufferMemory(memory_key="chat_history")
    readonlymemory = ReadOnlySharedMemory(memory=memory)
    summary_chain = LLMChain(
        llm=llm_model_ins,
        prompt=summary_prompt,
        verbose=True,
        memory=readonlymemory,  # use the read-only memory to prevent the tool from modifying the memory
    )
    tools = [
        Tool(
            name="Summary",
            func=summary_chain.run,
            description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
        )
    ]

    # Agent prompt. {chat_history} is filled from the executor's memory, so
    # the agent itself also sees previous turns.
    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    suffix = """Begin!
{chat_history}
Question: {input}
{agent_scratchpad}"""
    agent_prompt = CustomLLMSingleActionAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "chat_history", "agent_scratchpad"]
    )
    tool_names = [tool.name for tool in tools]
    llm_chain = LLMChain(llm=llm_model_ins, prompt=agent_prompt)
    agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names)
    # Attaching the memory here is what records each turn. Without it the
    # buffer stays empty and the Summary tool has nothing to summarise.
    agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, memory=memory)

    agent_chain.run(input="你好")
    agent_chain.run(input="你是谁?")
    agent_chain.run(input="我们之前聊了什么?")
if __name__ == '__main__':
    # Hard-coded options: a local Vicuna-13B checkpoint under /media/checkpoint/,
    # loaded in 8-bit without fetching a remote model.
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit'])
    # asyncio.run creates a fresh event loop, installs it as the current loop,
    # runs dispatch to completion and closes the loop afterwards.
    asyncio.run(dispatch(args))