Merge branch 'add_openalex_papers' into feature/graph_recommandation
Browse files- app.py +135 -9
- climateqa/engine/chains/answer_rag.py +19 -17
- climateqa/knowledge/openalex.py +3 -4
- front/utils.py +22 -0
app.py
CHANGED
@@ -8,6 +8,7 @@ from sentence_transformers import CrossEncoder
|
|
8 |
oa = OpenAlex()
|
9 |
|
10 |
import gradio as gr
|
|
|
11 |
import pandas as pd
|
12 |
import numpy as np
|
13 |
import os
|
@@ -44,11 +45,11 @@ from climateqa.sample_questions import QUESTIONS
|
|
44 |
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
|
45 |
from climateqa.utils import get_image_from_azure_blob_storage
|
46 |
from climateqa.engine.keywords import make_keywords_chain
|
47 |
-
|
48 |
-
from climateqa.engine.graph import make_graph_agent
|
49 |
from climateqa.engine.embeddings import get_embeddings_function
|
50 |
|
51 |
-
from front.utils import serialize_docs,process_figures
|
52 |
|
53 |
from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
|
54 |
|
@@ -115,9 +116,7 @@ vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name =
|
|
115 |
|
116 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
117 |
reranker = get_reranker("nano")
|
118 |
-
# agent = make_graph_agent(llm,vectorstore,reranker)
|
119 |
|
120 |
-
# agent = make_graph_agent(llm,vectorstore,reranker)
|
121 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
|
122 |
|
123 |
|
@@ -248,13 +247,11 @@ def generate_keywords(query):
|
|
248 |
|
249 |
|
250 |
papers_cols_widths = {
|
251 |
-
"doc":50,
|
252 |
"id":100,
|
253 |
"title":300,
|
254 |
"doi":100,
|
255 |
"publication_year":100,
|
256 |
"abstract":500,
|
257 |
-
"rerank_score":100,
|
258 |
"is_oa":50,
|
259 |
}
|
260 |
|
@@ -262,6 +259,62 @@ papers_cols = list(papers_cols_widths.keys())
|
|
262 |
papers_cols_widths = list(papers_cols_widths.values())
|
263 |
|
264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
# --------------------------------------------------------------------
|
266 |
# Gradio
|
267 |
# --------------------------------------------------------------------
|
@@ -363,7 +416,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
363 |
samples.append(group_examples)
|
364 |
|
365 |
|
366 |
-
with gr.Tab("Sources",elem_id = "tab-
|
367 |
sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
|
368 |
docs_textbox = gr.State("")
|
369 |
|
@@ -379,7 +432,28 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
379 |
show_full_size_figures.click(lambda : Modal(visible=True),None,modal)
|
380 |
|
381 |
figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
|
382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
|
384 |
with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=4) as tab_recommended_content:
|
385 |
graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
|
@@ -511,6 +585,38 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
511 |
# with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
|
512 |
# gallery_component = gr.Gallery(object_fit='cover')
|
513 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
514 |
# with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
|
515 |
|
516 |
# with gr.Row():
|
@@ -571,6 +677,21 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
571 |
|
572 |
def finish_chat():
|
573 |
return gr.update(interactive = True,value = "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
|
575 |
|
576 |
def change_completion_status(current_state):
|
@@ -618,6 +739,11 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
618 |
|
619 |
dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
|
620 |
|
|
|
|
|
|
|
|
|
|
|
621 |
|
622 |
demo.queue()
|
623 |
|
|
|
8 |
oa = OpenAlex()
|
9 |
|
10 |
import gradio as gr
|
11 |
+
from gradio_modal import Modal
|
12 |
import pandas as pd
|
13 |
import numpy as np
|
14 |
import os
|
|
|
45 |
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
|
46 |
from climateqa.utils import get_image_from_azure_blob_storage
|
47 |
from climateqa.engine.keywords import make_keywords_chain
|
48 |
+
from climateqa.engine.chains.answer_rag import make_rag_papers_chain
|
49 |
+
from climateqa.engine.graph import make_graph_agent
|
50 |
from climateqa.engine.embeddings import get_embeddings_function
|
51 |
|
52 |
+
from front.utils import serialize_docs,process_figures,make_html_df
|
53 |
|
54 |
from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
|
55 |
|
|
|
116 |
|
117 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
118 |
reranker = get_reranker("nano")
|
|
|
119 |
|
|
|
120 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
|
121 |
|
122 |
|
|
|
247 |
|
248 |
|
249 |
papers_cols_widths = {
|
|
|
250 |
"id":100,
|
251 |
"title":300,
|
252 |
"doi":100,
|
253 |
"publication_year":100,
|
254 |
"abstract":500,
|
|
|
255 |
"is_oa":50,
|
256 |
}
|
257 |
|
|
|
259 |
papers_cols_widths = list(papers_cols_widths.values())
|
260 |
|
261 |
|
262 |
+
async def find_papers(query,after):
|
263 |
+
|
264 |
+
summary = ""
|
265 |
+
keywords = generate_keywords(query)
|
266 |
+
df_works = oa.search(keywords,after = after)
|
267 |
+
df_works = df_works.dropna(subset=["abstract"])
|
268 |
+
df_works = oa.rerank(query,df_works,reranker)
|
269 |
+
df_works = df_works.sort_values("rerank_score",ascending=False)
|
270 |
+
docs_html = []
|
271 |
+
for i in range(10):
|
272 |
+
docs_html.append(make_html_df(df_works, i))
|
273 |
+
docs_html = "".join(docs_html)
|
274 |
+
print(docs_html)
|
275 |
+
G = oa.make_network(df_works)
|
276 |
+
|
277 |
+
height = "750px"
|
278 |
+
network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
|
279 |
+
network_html = network.generate_html()
|
280 |
+
|
281 |
+
network_html = network_html.replace("'", "\"")
|
282 |
+
css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
|
283 |
+
network_html = network_html + css_to_inject
|
284 |
+
|
285 |
+
|
286 |
+
network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
|
287 |
+
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
|
288 |
+
allow-scripts allow-same-origin allow-popups
|
289 |
+
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
|
290 |
+
allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
|
291 |
+
|
292 |
+
|
293 |
+
docs = df_works["content"].head(10).tolist()
|
294 |
+
|
295 |
+
df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
|
296 |
+
df_works["doc"] = df_works["doc"] + 1
|
297 |
+
df_works = df_works[papers_cols]
|
298 |
+
|
299 |
+
yield docs_html, network_html, summary
|
300 |
+
|
301 |
+
chain = make_rag_papers_chain(llm)
|
302 |
+
result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
|
303 |
+
path_answer = "/logs/StrOutputParser/streamed_output/-"
|
304 |
+
|
305 |
+
async for op in result:
|
306 |
+
|
307 |
+
op = op.ops[0]
|
308 |
+
|
309 |
+
if op['path'] == path_answer: # reforulated question
|
310 |
+
new_token = op['value'] # str
|
311 |
+
summary += new_token
|
312 |
+
else:
|
313 |
+
continue
|
314 |
+
yield docs_html, network_html, summary
|
315 |
+
|
316 |
+
|
317 |
+
|
318 |
# --------------------------------------------------------------------
|
319 |
# Gradio
|
320 |
# --------------------------------------------------------------------
|
|
|
416 |
samples.append(group_examples)
|
417 |
|
418 |
|
419 |
+
with gr.Tab("Sources",elem_id = "tab-sources",id = 1):
|
420 |
sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
|
421 |
docs_textbox = gr.State("")
|
422 |
|
|
|
432 |
show_full_size_figures.click(lambda : Modal(visible=True),None,modal)
|
433 |
|
434 |
figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
|
435 |
+
|
436 |
+
|
437 |
+
|
438 |
+
with gr.Tab("Papers",elem_id = "tab-citations",id = 5):
|
439 |
+
btn_summary = gr.Button("Summary")
|
440 |
+
# Fenêtre simulée pour le Summary
|
441 |
+
with gr.Group(visible=False, elem_id="papers-summary-popup") as summary_popup:
|
442 |
+
papers_summary = gr.Markdown("### Summary Content", visible=True, elem_id="papers-summary")
|
443 |
+
|
444 |
+
btn_relevant_papers = gr.Button("Relevant papers")
|
445 |
+
# Fenêtre simulée pour les Relevant Papers
|
446 |
+
with gr.Group(visible=False, elem_id="papers-relevant-popup") as relevant_popup:
|
447 |
+
papers_html = gr.HTML(show_label=False, elem_id="sources-textbox")
|
448 |
+
docs_textbox = gr.State("")
|
449 |
+
|
450 |
+
btn_citations_network = gr.Button("Citations network")
|
451 |
+
# Fenêtre simulée pour le Citations Network
|
452 |
+
with Modal(visible=False) as modal:
|
453 |
+
citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
|
454 |
+
btn_citations_network.click(lambda: Modal(visible=True), None, modal)
|
455 |
+
|
456 |
+
|
457 |
|
458 |
with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=4) as tab_recommended_content:
|
459 |
graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
|
|
|
585 |
# with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
|
586 |
# gallery_component = gr.Gallery(object_fit='cover')
|
587 |
|
588 |
+
with gr.Tab("Settings",elem_id = "tab-config",id = 2):
|
589 |
+
|
590 |
+
gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
|
591 |
+
|
592 |
+
|
593 |
+
dropdown_sources = gr.CheckboxGroup(
|
594 |
+
["IPCC", "IPBES","IPOS", "OpenAlex"],
|
595 |
+
label="Select source",
|
596 |
+
value=["IPCC"],
|
597 |
+
interactive=True,
|
598 |
+
)
|
599 |
+
|
600 |
+
dropdown_reports = gr.Dropdown(
|
601 |
+
POSSIBLE_REPORTS,
|
602 |
+
label="Or select specific reports",
|
603 |
+
multiselect=True,
|
604 |
+
value=None,
|
605 |
+
interactive=True,
|
606 |
+
)
|
607 |
+
|
608 |
+
dropdown_audience = gr.Dropdown(
|
609 |
+
["Children","General public","Experts"],
|
610 |
+
label="Select audience",
|
611 |
+
value="Experts",
|
612 |
+
interactive=True,
|
613 |
+
)
|
614 |
+
|
615 |
+
after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
|
616 |
+
|
617 |
+
output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
|
618 |
+
output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
|
619 |
+
|
620 |
# with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
|
621 |
|
622 |
# with gr.Row():
|
|
|
677 |
|
678 |
def finish_chat():
|
679 |
return gr.update(interactive = True,value = "")
|
680 |
+
|
681 |
+
# Initialize visibility states
|
682 |
+
summary_visible = False
|
683 |
+
relevant_visible = False
|
684 |
+
|
685 |
+
# Functions to toggle visibility
|
686 |
+
def toggle_summary_visibility():
|
687 |
+
global summary_visible
|
688 |
+
summary_visible = not summary_visible
|
689 |
+
return gr.update(visible=summary_visible)
|
690 |
+
|
691 |
+
def toggle_relevant_visibility():
|
692 |
+
global relevant_visible
|
693 |
+
relevant_visible = not relevant_visible
|
694 |
+
return gr.update(visible=relevant_visible)
|
695 |
|
696 |
|
697 |
def change_completion_status(current_state):
|
|
|
739 |
|
740 |
dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
|
741 |
|
742 |
+
textbox.submit(find_papers,[textbox,after], [papers_html,citations_network,papers_summary])
|
743 |
+
examples_hidden.change(find_papers,[examples_hidden,after], [papers_html,citations_network,papers_summary])
|
744 |
+
|
745 |
+
btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
|
746 |
+
btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
|
747 |
|
748 |
demo.queue()
|
749 |
|
climateqa/engine/chains/answer_rag.py
CHANGED
@@ -7,6 +7,8 @@ from langchain_core.prompts.base import format_document
|
|
7 |
|
8 |
from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
|
9 |
from climateqa.engine.chains.prompts import papers_prompt_template
|
|
|
|
|
10 |
|
11 |
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
|
12 |
|
@@ -71,32 +73,32 @@ def make_rag_node(llm,with_docs = True):
|
|
71 |
|
72 |
|
73 |
|
74 |
-
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
|
82 |
-
|
83 |
-
|
84 |
|
85 |
-
|
86 |
|
87 |
|
88 |
|
89 |
|
90 |
|
91 |
|
92 |
-
|
93 |
|
94 |
-
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
|
101 |
-
|
102 |
-
|
|
|
7 |
|
8 |
from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
|
9 |
from climateqa.engine.chains.prompts import papers_prompt_template
|
10 |
+
from ..utils import rename_chain, pass_values
|
11 |
+
|
12 |
|
13 |
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
|
14 |
|
|
|
73 |
|
74 |
|
75 |
|
76 |
+
def make_rag_papers_chain(llm):
|
77 |
|
78 |
+
prompt = ChatPromptTemplate.from_template(papers_prompt_template)
|
79 |
+
input_documents = {
|
80 |
+
"context":lambda x : _combine_documents(x["docs"]),
|
81 |
+
**pass_values(["question","language"])
|
82 |
+
}
|
83 |
|
84 |
+
chain = input_documents | prompt | llm | StrOutputParser()
|
85 |
+
chain = rename_chain(chain,"answer")
|
86 |
|
87 |
+
return chain
|
88 |
|
89 |
|
90 |
|
91 |
|
92 |
|
93 |
|
94 |
+
def make_illustration_chain(llm):
|
95 |
|
96 |
+
prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
|
97 |
|
98 |
+
input_description_images = {
|
99 |
+
"images":lambda x : _combine_documents(get_image_docs(x["docs"])),
|
100 |
+
**pass_values(["question","audience","language","answer"]),
|
101 |
+
}
|
102 |
|
103 |
+
illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
|
104 |
+
return illustration_chain
|
climateqa/knowledge/openalex.py
CHANGED
@@ -62,11 +62,10 @@ class OpenAlex():
|
|
62 |
|
63 |
scores = reranker.rank(
|
64 |
query,
|
65 |
-
df["content"].tolist()
|
66 |
-
top_k = len(df),
|
67 |
)
|
68 |
-
scores.
|
69 |
-
scores = [x
|
70 |
df["rerank_score"] = scores
|
71 |
return df
|
72 |
|
|
|
62 |
|
63 |
scores = reranker.rank(
|
64 |
query,
|
65 |
+
df["content"].tolist()
|
|
|
66 |
)
|
67 |
+
scores = sorted(scores.results, key = lambda x : x.document.doc_id)
|
68 |
+
scores = [x.score for x in scores]
|
69 |
df["rerank_score"] = scores
|
70 |
return df
|
71 |
|
front/utils.py
CHANGED
@@ -228,6 +228,28 @@ def make_html_source(source,i):
|
|
228 |
return card
|
229 |
|
230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
def make_html_figure_sources(source,i,img_str):
|
232 |
meta = source.metadata
|
233 |
content = source.page_content.strip()
|
|
|
228 |
return card
|
229 |
|
230 |
|
231 |
+
def make_html_df(df,i):
|
232 |
+
title = df['title'][i]
|
233 |
+
content = df['abstract'][i]
|
234 |
+
url = df['doi'][i]
|
235 |
+
publication_date = df['publication_year'][i]
|
236 |
+
|
237 |
+
card = f"""
|
238 |
+
<div class="card" id="doc{i}">
|
239 |
+
<div class="card-content">
|
240 |
+
<h2>Doc {i+1} - {title}</h2>
|
241 |
+
<p>{content}</p>
|
242 |
+
</div>
|
243 |
+
<div class="card-footer">
|
244 |
+
<span>{publication_date}</span>
|
245 |
+
<a href="{url}" target="_blank" class="pdf-link">
|
246 |
+
</div>
|
247 |
+
</div>
|
248 |
+
"""
|
249 |
+
|
250 |
+
return card
|
251 |
+
|
252 |
+
|
253 |
def make_html_figure_sources(source,i,img_str):
|
254 |
meta = source.metadata
|
255 |
content = source.page_content.strip()
|