import gradio as gr
import os
from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain import PromptTemplate, HuggingFaceHub, LLMChain, ConversationChain
from langchain.llms import OpenAI
from langchain.chains.conversation.memory import ConversationBufferMemory
from threading import Lock
import openai
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError
from typing import Optional, Tuple
# Tools the agent loads by default; serpapi and news-api require their own API keys.
TOOLS_DEFAULT_LIST = ['serpapi', 'news-api', 'pal-math']

# Upper bound on tokens the LLM may generate per completion.
MAX_TOKENS = 512

# Prompt used by the "express" chain to restate the agent's raw answer.
PROMPT_TEMPLATE = PromptTemplate(
    input_variables=["original_words"],
    template="Restate the following: \n{original_words}\n",
)

BUG_FOUND_MSG = "Congratulations, you've found a bug in this application!"
AUTH_ERR_MSG = "Please paste your OpenAI key."

# Use .get() so a missing NEWS_API_KEY does not raise KeyError at import time;
# load_tools() will surface a clearer error if the key is actually needed.
news_api_key = os.environ.get("NEWS_API_KEY")
def run_chain(chain, inp, capture_hidden_text):
    """Run ``chain`` on ``inp``, mapping known failures to user-facing messages.

    Parameters
    ----------
    chain : object exposing ``.run(input=...)`` (e.g. a LangChain chain/agent).
    inp : str
        The user's message.
    capture_hidden_text : bool
        Reserved for future use; hidden-text capture is not implemented here,
        so the second return value is always ``None``.

    Returns
    -------
    tuple
        ``(output, hidden_text)`` where ``output`` is the chain's answer or an
        error message, and ``hidden_text`` is always ``None``.
    """
    hidden_text = None  # see docstring: capture_hidden_text is accepted but unused
    try:
        output = chain.run(input=inp)
    except AuthenticationError:
        # Bad/missing OpenAI key — show a friendly prompt instead of the raw error.
        output = AUTH_ERR_MSG
    except RateLimitError as rle:
        output = "\n\nRateLimitError: " + str(rle)
    except ValueError as ve:
        output = "\n\nValueError: " + str(ve)
    except InvalidRequestError as ire:
        output = "\n\nInvalidRequestError: " + str(ire)
    except Exception as e:
        # Anything else indicates a genuine bug in this app; surface it.
        output = "\n\n" + BUG_FOUND_MSG + ":\n\n" + str(e)
    return output, hidden_text
def transform_text(desc, express_chain):
    """Post-process agent output for display.

    ``express_chain`` is accepted for interface compatibility but is currently
    unused: the text is returned as-is apart from newline expansion. (The
    original code also formatted PROMPT_TEMPLATE with ``desc`` but never fed
    the result to any chain, so that dead call has been removed.)

    Parameters
    ----------
    desc : str
        Raw text produced by the agent.
    express_chain : LLMChain or None
        Unused; kept so callers do not break.

    Returns
    -------
    str
        ``desc`` with every newline doubled so Markdown renders paragraphs.
    """
    generated_text = desc
    # Replace every "\n" with "\n\n" so line breaks survive Markdown rendering.
    generated_text = generated_text.replace("\n", "\n\n")
    return generated_text
class ChatWrapper:
    """Callable chat handler; a lock serializes concurrent Gradio requests."""

    def __init__(self):
        self.lock = Lock()

    def __call__(
            self, api_key: str, inp: str, history: Optional[Tuple[str, str]],
            chain: Optional[ConversationChain], express_chain: Optional[LLMChain]):
        """Execute the chat functionality.

        Returns the updated history twice (Gradio expects one copy for the
        chatbot component and one for the history state).
        """
        # `with` guarantees the lock is released even if an exception escapes;
        # this replaces the old manual acquire/try/finally/release plus the
        # no-op `except Exception as e: raise e`. Leftover debug prints removed.
        with self.lock:
            history = history or []
            # If chain is None, that is because no API key was provided.
            output = "Please paste your OpenAI key to use this application."
            hidden_text = output
            if chain and chain != "":
                # Set OpenAI key
                openai.api_key = api_key
                output, hidden_text = run_chain(chain, inp, capture_hidden_text=False)
                output = transform_text(output, express_chain)
            text_to_display = output
            history.append((inp, text_to_display))
        # return history, history, html_video, temp_file, ""
        return history, history


# Module-level singleton used as the Gradio callback.
chat = ChatWrapper()
def load_chain(tools_list, llm):
    """Build the agent chain and the restating "express" chain for ``llm``.

    Parameters
    ----------
    tools_list : list[str]
        LangChain tool names to load (see TOOLS_DEFAULT_LIST).
    llm : BaseLLM
        Language model backing both chains.

    Returns
    -------
    tuple
        ``(chain, express_chain)`` — the zero-shot agent and the LLMChain that
        restates its answers.
    """
    print("\ntools_list", tools_list)
    tools = load_tools(tools_list, llm=llm, news_api_key=news_api_key)
    # Buffer memory lets the zero-shot agent see prior conversation turns.
    memory = ConversationBufferMemory(memory_key="chat_history")
    chain = initialize_agent(tools, llm, agent="zero-shot-react-description",
                             verbose=True, memory=memory)
    # Secondary chain that restates the agent's raw answer (see PROMPT_TEMPLATE).
    express_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
    return chain, express_chain
def set_openai_api_key(api_key):
    """Set the api key and return chain.
    If no api_key, then None is returned.
    """
    # Honor the documented contract: without a key there is nothing to build.
    # (Previously the chains were built unconditionally, even with an empty key.)
    if not api_key:
        return None, None, None
    os.environ["OPENAI_API_KEY"] = api_key
    llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
    chain, express_chain = load_chain(TOOLS_DEFAULT_LIST, llm)
    # Scrub the key from the environment once the LLM client has captured it.
    os.environ["OPENAI_API_KEY"] = ""
    return chain, express_chain, llm
with gr.Blocks() as app:
llm_state = gr.State()
history_state = gr.State()
chain_state = gr.State()
express_chain_state = gr.State()
with gr.Row():
with gr.Column():
gr.HTML(
"""