Spaces:
Sleeping
Sleeping
import os | |
from datetime import datetime | |
import gradio as gr | |
from pinecone import Pinecone | |
from huggingface_hub import whoami | |
from langchain.prompts import ChatPromptTemplate | |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
from langchain.prompts.prompt import PromptTemplate | |
from langchain_groq import ChatGroq | |
from langchain.memory import ConversationBufferMemory | |
from langchain_community.vectorstores import Pinecone as PineconeVectorstore | |
from celsius_csrd_chatbot.utils import ( | |
make_html_source, | |
make_pairs, | |
_format_chat_history, | |
_combine_documents, | |
init_env, | |
parse_output_llm_with_sources, | |
) | |
from celsius_csrd_chatbot.agent import make_graph_agent, display_graph | |
# Environment bootstrap -- presumably populates the API keys read below
# (PINECONE_API_KEY, PINECONE_API_INDEX, Groq credentials); TODO confirm
# against celsius_csrd_chatbot.utils.init_env.
init_env()

demo_name = "ESRS_QA"

# Embedding model for the retriever. Vectors are normalised to unit length,
# so a dot product between them equals cosine similarity.
hf_model = "BAAI/bge-base-en-v1.5"
embeddings = HuggingFaceBgeEmbeddings(
    model_name=hf_model,
    encode_kwargs={"normalize_embeddings": True},
)

# Vector store over an existing Pinecone index. The third argument is the
# metadata key that holds each chunk's text.
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index(os.getenv("PINECONE_API_INDEX"))
vectorstore = PineconeVectorstore(index, embeddings, "page_content")

# Temperature 0 keeps the answers as deterministic as the provider allows.
llm = ChatGroq(model_name="llama-3.2-90b-text-preview", temperature=0)
agent = make_graph_agent(llm, vectorstore)

# NOTE(review): not referenced anywhere else in this file.
memory = ConversationBufferMemory(
    input_key="question",
    output_key="answer",
    return_messages=True,
)
async def chat(query, history):
    """Stream the agent's answer to ``query`` into the Gradio chat history.

    Runs the graph agent on ``{"query": query}`` and consumes its
    ``astream_events`` stream (``version="v1"``). After every event it
    yields ``(history, docs_html)``:

    * ``history`` -- the chat history as a list of ``(user, bot)`` tuples.
      Its last entry is rewritten as the answer progresses: a status
      message while a tracked step is starting, the streamed LLM tokens,
      or the ``answer`` emitted by the ``answer_rag_wrong`` step.
    * ``docs_html`` -- HTML for the retrieved source documents, or ``""``
      until ``retrieve_documents`` has finished successfully.

    Args:
        query: The user's question.
        history: Current chat history. Its last entry is the turn being
            answered (``start_chat`` appends ``(query, None)`` before this
            runs). If it is empty, a fresh ``(query, None)`` turn is used.

    Raises:
        gr.Error: Wraps any exception raised while streaming, so Gradio shows
            it to the user. The original exception is chained as the cause.
    """
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    # Status message displayed when each tracked agent step starts.
    steps_display = {
        "categorize_esrs": "🔄️ Analyzing user query",
        "retrieve_documents": "🔄️ Searching in the knowledge base",
    }

    # Every write below targets history[-1], so make sure a turn exists.
    history = list(history) if history else [(query, None)]
    docs_html = ""
    start_streaming = False

    try:
        async for event in agent.astream_events({"query": query}, version="v1"):
            kind = event["event"]
            name = event["name"]

            if kind == "on_chat_model_stream":
                if not start_streaming:
                    # First token: clear the status message so the answer
                    # replaces it rather than being appended to it.
                    start_streaming = True
                    history[-1] = (query, "")
                previous_answer = history[-1][1]
                if previous_answer is None:
                    previous_answer = ""
                new_token = event["data"]["chunk"].content
                answer_yet = parse_output_llm_with_sources(previous_answer + new_token)
                history[-1] = (query, answer_yet)
            elif name == "answer_rag_wrong" and kind == "on_chain_stream":
                history[-1] = (query, event["data"]["chunk"]["answer"])
            elif name == "retrieve_documents" and kind == "on_chain_end":
                # Best effort: a rendering failure must not abort the answer.
                # The HTML is built in a single expression, so docs_html only
                # changes on success and never holds a partial list.
                try:
                    docs = event["data"]["output"]["documents"]
                    docs_html = "".join(
                        make_html_source(i, doc) for i, doc in enumerate(docs, 1)
                    )
                except Exception as e:
                    print(f"Error getting documents: {e}")
                    print(event)

            if kind == "on_chain_start" and name in steps_display:
                history[-1] = (query, steps_display[name])

            history = [tuple(turn) for turn in history]
            yield history, docs_html
    except Exception as e:
        raise gr.Error(f"{e}") from e
# Custom stylesheet for the app. It is decoded explicitly as UTF-8 so that
# loading does not depend on the platform's default locale encoding.
with open("./assets/style.css", "r", encoding="utf-8") as f:
    css = f.read()

# Gradio theme: blue/red hues with the Poppins web font and system fallbacks.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)

# Greeting used as the bot's first chat message (contains Markdown emphasis).
init_prompt = """
Hello, I am ESRS Q&A, a conversational assistant designed to help you understand the content of European Sustainability Reporting Standards (ESRS). I will answer your questions based **on the official definition of each ESRS as well as complementary guidelines**.
⚠️ Limitations
*Please note that this chatbot is in an early stage phase, it is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
What do you want to learn ?
"""
def start_chat(query, history):
    """Register a pending turn for ``query`` and lock the input box.

    Returns a Gradio update that disables the textbox. It also returns the
    chat history (as tuples) with ``(query, None)`` appended, giving ``chat``
    a last entry to stream its answer into.
    """
    pending = [*history, (query, None)]
    return gr.update(interactive=False), [tuple(turn) for turn in pending]


def finish_chat():
    """Clear the input textbox and make it editable again."""
    return gr.update(interactive=True, value="")


with gr.Blocks(title=demo_name, css=css, theme=theme) as demo:
    with gr.Column(visible=True) as bloc_2:
        with gr.Tab("ESRS Q&A"):
            with gr.Row():
                # Wider column: the conversation and the question box.
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        value=[(None, init_prompt)],
                        show_copy_button=True,
                        show_label=False,
                        elem_id="chatbot",
                        layout="panel",
                        avatar_images=(None, "https://i.ibb.co/cN0czLp/celsius-logo.png"),
                    )
                    state = gr.State([])
                    with gr.Row(elem_id="input-message"):
                        ask = gr.Textbox(
                            placeholder="Ask me anything here!",
                            show_label=False,
                            scale=7,
                            lines=1,
                            interactive=True,
                            elem_id="input-textbox",
                        )
                # Narrower side panel: HTML of the retrieved sources.
                with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                    with gr.Tab("Sources", elem_id="tab-citations", id=1):
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")
        with gr.Tab("About", elem_classes="max-height other-tabs"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("WIP")

    # Submit pipeline:
    #   1. Lock the textbox and append a pending turn (bypasses the queue).
    #   2. Stream the answer and the sources.
    #   3. Unlock and clear the textbox.
    ask.submit(
        fn=start_chat,
        inputs=[ask, chatbot],
        outputs=[ask, chatbot],
        queue=False,
        api_name="start_chat_textbox",
    ).then(
        fn=chat,
        inputs=[ask, chatbot],
        outputs=[chatbot, sources_textbox],
    ).then(
        fn=finish_chat,
        inputs=None,
        outputs=[ask],
        api_name="finish_chat_textbox",
    )
# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. from tests or another script) does not launch it.
# NOTE(review): share=True requests a public share link; confirm it is wanted
# when running outside Hugging Face Spaces.
if __name__ == "__main__":
    demo.launch(
        share=True,
        debug=True,
    )