|
import sys |
|
import os |
|
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../../') |
|
import asyncio |
|
from argparse import Namespace |
|
from models.loader.args import parser |
|
from models.loader import LoaderCheckPoint |
|
|
|
|
|
import models.shared as shared |
|
|
|
from langchain.chains import LLMChain |
|
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory |
|
from langchain.prompts import PromptTemplate |
|
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor |
|
from typing import List, Set |
|
|
|
|
|
|
|
class CustomLLMSingleActionAgent(ZeroShotAgent):
    """ZeroShotAgent variant that keeps an explicit whitelist of tool names.

    ``allowed_tools`` is a declared field on the (pydantic-based) agent model,
    so ``super().__init__`` already validates and assigns it when the keyword
    is supplied.
    """

    # Names of the tools this agent is permitted to invoke.
    allowed_tools: List[str]

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ZeroShotAgent.

        Fix: the original unconditionally read ``kwargs['allowed_tools']``,
        raising ``KeyError`` when the keyword was omitted. The base-class
        constructor already stores the field, so only re-assign when the
        caller actually provided it.
        """
        super().__init__(*args, **kwargs)
        if 'allowed_tools' in kwargs:
            self.allowed_tools = kwargs['allowed_tools']

    def get_allowed_tools(self) -> Set[str]:
        """Return the whitelist as a set for O(1) membership checks."""
        return set(self.allowed_tools)
|
|
|
|
|
async def dispatch(args: Namespace):
    """Load the checkpointed LLM and run a short agent demo conversation.

    Builds a summarization tool backed by a read-only shared memory, wires it
    into a ZeroShot-style agent, and issues three demo queries.

    NOTE(review): declared ``async`` but contains no ``await``; kept async so
    the existing ``run_until_complete`` call sites continue to work.

    :param args: parsed CLI namespace (model dir/name, quantization flags …)
        consumed by ``LoaderCheckPoint``.
    """
    args_dict = vars(args)

    # Initialize the shared model checkpoint, then instantiate the LLM wrapper.
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()

    template = """This is a conversation between a human and a bot:

{chat_history}

Write a summary of the conversation for {input}:
"""

    summary_prompt = PromptTemplate(
        input_variables=["input", "chat_history"],
        template=template
    )
    memory = ConversationBufferMemory(memory_key="chat_history")
    # Read-only view so the summary tool cannot mutate the shared history.
    readonlymemory = ReadOnlySharedMemory(memory=memory)
    summary_chain = LLMChain(
        llm=llm_model_ins,
        prompt=summary_prompt,
        verbose=True,
        memory=readonlymemory,
    )

    tools = [
        Tool(
            name="Summary",
            func=summary_chain.run,
            description="useful for when you summarize a conversation. The input to this tool should be a string, representing who will read this summary."
        )
    ]

    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    suffix = """Begin!

Question: {input}
{agent_scratchpad}"""

    # Separate name from the summarization prompt above — the original
    # rebound `prompt`, shadowing the first prompt object.
    agent_prompt = CustomLLMSingleActionAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "agent_scratchpad"]
    )
    tool_names = [tool.name for tool in tools]
    llm_chain = LLMChain(llm=llm_model_ins, prompt=agent_prompt)
    agent = CustomLLMSingleActionAgent(llm_chain=llm_chain, tools=tools, allowed_tools=tool_names)
    agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools)

    # Three demo turns (greeting / identity / recall of prior context).
    agent_chain.run(input="你好")
    agent_chain.run(input="你是谁?")
    agent_chain.run(input="我们之前聊了什么?")
|
|
|
if __name__ == '__main__':
    # Hard-coded demo arguments (the original also had a dead `args = None`
    # assignment immediately overwritten here — removed).
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'vicuna-13b-hf', '--no-remote-model', '--load-in-8bit'])

    # asyncio.run creates, uses, and properly closes a fresh event loop,
    # replacing the manual new_event_loop/set_event_loop/run_until_complete
    # sequence, which never closed the loop.
    asyncio.run(dispatch(args))
|
|