Add OpenAlex papers recommendation
- app.py +136 -34
- climateqa/engine/chains/answer_rag.py +19 -17
- climateqa/knowledge/openalex.py +3 -4
- front/utils.py +22 -0
app.py
CHANGED
@@ -8,6 +8,7 @@ from sentence_transformers import CrossEncoder
 oa = OpenAlex()
 
 import gradio as gr
+from gradio_modal import Modal
 import pandas as pd
 import numpy as np
 import os
@@ -42,10 +43,12 @@ from climateqa.sample_questions import QUESTIONS
 from climateqa.constants import POSSIBLE_REPORTS
 from climateqa.utils import get_image_from_azure_blob_storage
 from climateqa.engine.keywords import make_keywords_chain
-
+from climateqa.engine.chains.answer_rag import make_rag_papers_chain
+from climateqa.engine.graph import make_graph_agent,display_graph
+
 from climateqa.engine.graph import make_graph_agent,display_graph
 
-from front.utils import make_html_source, make_html_figure_sources,parse_output_llm_with_sources,serialize_docs,make_toolbox
+from front.utils import make_html_source, make_html_figure_sources,parse_output_llm_with_sources,serialize_docs,make_toolbox,make_html_df
 
 # Load environment variables in local mode
 try:
@@ -106,7 +109,7 @@ CITATION_TEXT = r"""@misc{climateqa,
 # Create vectorstore and retriever
 vectorstore = get_pinecone_vectorstore(embeddings_function)
 llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
-reranker = get_reranker("
+reranker = get_reranker("nano")
 agent = make_graph_agent(llm,vectorstore,reranker)
 
 
@@ -326,13 +329,11 @@ def generate_keywords(query):
 
 
 papers_cols_widths = {
-    "doc":50,
     "id":100,
     "title":300,
     "doi":100,
     "publication_year":100,
     "abstract":500,
-    "rerank_score":100,
     "is_oa":50,
 }
 
@@ -340,6 +341,62 @@ papers_cols = list(papers_cols_widths.keys())
 papers_cols_widths = list(papers_cols_widths.values())
 
 
+async def find_papers(query,after):
+
+    summary = ""
+    keywords = generate_keywords(query)
+    df_works = oa.search(keywords,after = after)
+    df_works = df_works.dropna(subset=["abstract"])
+    df_works = oa.rerank(query,df_works,reranker)
+    df_works = df_works.sort_values("rerank_score",ascending=False)
+    docs_html = []
+    for i in range(10):
+        docs_html.append(make_html_df(df_works, i))
+    docs_html = "".join(docs_html)
+    print(docs_html)
+    G = oa.make_network(df_works)
+
+    height = "750px"
+    network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
+    network_html = network.generate_html()
+
+    network_html = network_html.replace("'", "\"")
+    css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
+    network_html = network_html + css_to_inject
+
+    network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
+    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+    allow-scripts allow-same-origin allow-popups
+    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+    allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
+
+    docs = df_works["content"].head(10).tolist()
+
+    df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
+    df_works["doc"] = df_works["doc"] + 1
+    df_works = df_works[papers_cols]
+
+    yield docs_html, network_html, summary
+
+    chain = make_rag_papers_chain(llm)
+    result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
+    path_answer = "/logs/StrOutputParser/streamed_output/-"
+
+    async for op in result:
+
+        op = op.ops[0]
+
+        if op['path'] == path_answer: # reformulated question
+            new_token = op['value'] # str
+            summary += new_token
+        else:
+            continue
+        yield docs_html, network_html, summary
+
+
 # --------------------------------------------------------------------
 # Gradio
 # --------------------------------------------------------------------
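The new `find_papers` handler is an async generator: it yields the rendered paper cards and the citation network as soon as reranking finishes, then keeps re-yielding the same outputs while the summary streams out of the papers chain. A self-contained sketch of that yield-early, stream-later pattern, with toy data standing in for OpenAlex, the reranker, and the LLM:

```python
# Minimal sketch of the streaming pattern used by find_papers: yield the papers
# HTML and network once, then re-yield as the summary grows token by token.
# Toy data only -- no OpenAlex, reranker, or LLM calls.
import asyncio

async def fake_find_papers(query: str):
    docs_html = f"<div>papers for: {query}</div>"    # stands in for make_html_df output
    network_html = "<div>citation network</div>"     # stands in for the network iframe
    summary = ""
    yield docs_html, network_html, summary           # papers and network appear immediately
    for token in ["Warming ", "is ", "accelerating."]:  # stands in for the LLM token stream
        summary += token
        yield docs_html, network_html, summary       # each yield refreshes the summary

async def main():
    async for _, _, summary in fake_find_papers("sea level rise"):
        print(repr(summary))

asyncio.run(main())
```

Gradio treats each yield from a generator handler as a partial update, so the papers and the network stay on screen while the summary fills in.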
@@ -429,7 +486,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
         samples.append(group_examples)
 
 
-        with gr.Tab("Sources",elem_id = "tab-
+        with gr.Tab("Sources",elem_id = "tab-sources",id = 1):
             sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
             docs_textbox = gr.State("")
 
@@ -437,36 +494,29 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 
 
         # with Modal(visible = False) as config_modal:
-
-
-
-        dropdown_sources = gr.CheckboxGroup(
-            ["IPCC", "IPBES","IPOS"],
-            label="Select source",
-            value=["IPCC"],
-            interactive=True,
-        )
-
-        dropdown_reports = gr.Dropdown(
-            POSSIBLE_REPORTS,
-            label="Or select specific reports",
-            multiselect=True,
-            value=None,
-            interactive=True,
-        )
-
-        dropdown_audience = gr.Dropdown(
-            ["Children","General public","Experts"],
-            label="Select audience",
-            value="Experts",
-            interactive=True,
-        )
-
-        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
-        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
+
+
+        with gr.Tab("Papers",elem_id = "tab-citations",id = 4):
+            btn_summary = gr.Button("Summary")
+            # Simulated window for the Summary
+            with gr.Group(visible=False, elem_id="papers-summary-popup") as summary_popup:
+                papers_summary = gr.Markdown("### Summary Content", visible=True, elem_id="papers-summary")
+
+            btn_relevant_papers = gr.Button("Relevant papers")
+            # Simulated window for the Relevant Papers
+            with gr.Group(visible=False, elem_id="papers-relevant-popup") as relevant_popup:
+                papers_html = gr.HTML(show_label=False, elem_id="sources-textbox")
+                docs_textbox = gr.State("")
+
+            btn_citations_network = gr.Button("Citations network")
+            # Simulated window for the Citations Network
+            with Modal(visible=False) as modal:
+                citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
+            btn_citations_network.click(lambda: Modal(visible=True), None, modal)
 
 
         with gr.Tab("Figures",elem_id = "tab-figures",id = 3):
             with Modal(visible=False, elem_id="modal_figure_galery") as modal:
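The Papers tab relies on `gradio_modal` for the citations-network popup: the button's click handler returns `Modal(visible=True)`, which Gradio applies as the update for the hidden modal. A minimal standalone sketch of just that interaction, using the same calls as the hunk above (assumes the gradio and gradio_modal packages are installed):

```python
# Sketch of the "Citations network" popup wiring: clicking the button returns
# Modal(visible=True), which Gradio applies to the existing (hidden) modal.
import gradio as gr
from gradio_modal import Modal

with gr.Blocks() as demo:
    btn = gr.Button("Citations network")
    with Modal(visible=False) as modal:
        gr.HTML("<h3>Citations Network Graph</h3>")
    btn.click(lambda: Modal(visible=True), None, modal)

demo.launch()
```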
@@ -490,6 +540,38 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
         # with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
         #     gallery_component = gr.Gallery(object_fit='cover')
 
+        with gr.Tab("Settings",elem_id = "tab-config",id = 2):
+
+            gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
+
+
+            dropdown_sources = gr.CheckboxGroup(
+                ["IPCC", "IPBES","IPOS", "OpenAlex"],
+                label="Select source",
+                value=["IPCC"],
+                interactive=True,
+            )
+
+            dropdown_reports = gr.Dropdown(
+                POSSIBLE_REPORTS,
+                label="Or select specific reports",
+                multiselect=True,
+                value=None,
+                interactive=True,
+            )
+
+            dropdown_audience = gr.Dropdown(
+                ["Children","General public","Experts"],
+                label="Select audience",
+                value="Experts",
+                interactive=True,
+            )
+
+            after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
+
+            output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
+            output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
+
         # with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
 
         #     with gr.Row():
@@ -546,6 +628,21 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 
     def finish_chat():
        return (gr.update(interactive = True,value = ""))
+
+    # Initialize visibility states
+    summary_visible = False
+    relevant_visible = False
+
+    # Functions to toggle visibility
+    def toggle_summary_visibility():
+        global summary_visible
+        summary_visible = not summary_visible
+        return gr.update(visible=summary_visible)
+
+    def toggle_relevant_visibility():
+        global relevant_visible
+        relevant_visible = not relevant_visible
+        return gr.update(visible=relevant_visible)
 
     (textbox
         .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
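The summary and relevant-papers popups are plain `gr.Group` blocks driven by the helpers added above: each click flips a module-level flag and returns `gr.update(visible=...)` for the group. A runnable sketch of the same pattern in isolation:

```python
# Isolated sketch of the visibility toggle used for summary_popup / relevant_popup:
# flip a module-level flag and return gr.update(visible=...) so Gradio shows or
# hides the Group on each click.
import gradio as gr

popup_visible = False

def toggle_popup():
    global popup_visible
    popup_visible = not popup_visible
    return gr.update(visible=popup_visible)

with gr.Blocks() as demo:
    btn = gr.Button("Summary")
    with gr.Group(visible=False) as popup:
        gr.Markdown("### Summary Content")
    btn.click(toggle_popup, outputs=popup)

demo.launch()
```

Because the flag is a module-level global, the toggle state is shared by all sessions running in the same process.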
@@ -570,6 +667,11 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 
     dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
 
+    textbox.submit(find_papers,[textbox,after], [papers_html,citations_network,papers_summary])
+    examples_hidden.change(find_papers,[examples_hidden,after], [papers_html,citations_network,papers_summary])
+
+    btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
+    btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
 
 demo.queue()
 
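Note that `textbox.submit` is now registered twice on the same component: once in the existing chat wiring for the agent, and once here for `find_papers`. Gradio allows several listeners per event, each with its own inputs and outputs. A small sketch of that fan-out with stub handlers in place of the real ones:

```python
# Sketch of two listeners on one submit event, as in the wiring above.
# chat_answer and find_papers_stub are stand-ins for the real handlers.
import gradio as gr

def chat_answer(question):
    return f"Answer to: {question}"

def find_papers_stub(question, after):
    return f"Papers published after {after} for: {question}"

with gr.Blocks() as demo:
    textbox = gr.Textbox(label="Ask a question")
    after = gr.Slider(minimum=1950, maximum=2023, step=1, value=1960, label="Publication date")
    answer = gr.Markdown()
    papers_html = gr.HTML()

    textbox.submit(chat_answer, [textbox], [answer])                    # chat pipeline
    textbox.submit(find_papers_stub, [textbox, after], [papers_html])   # OpenAlex pipeline

demo.launch()
```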
climateqa/engine/chains/answer_rag.py
CHANGED
@@ -7,6 +7,8 @@ from langchain_core.prompts.base import format_document
 
 from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
 from climateqa.engine.chains.prompts import papers_prompt_template
+from ..utils import rename_chain, pass_values
+
 
 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
@@ -68,32 +70,32 @@ def make_rag_node(llm,with_docs = True):
 
 
 
+def make_rag_papers_chain(llm):
+
+    prompt = ChatPromptTemplate.from_template(papers_prompt_template)
+    input_documents = {
+        "context":lambda x : _combine_documents(x["docs"]),
+        **pass_values(["question","language"])
+    }
+
+    chain = input_documents | prompt | llm | StrOutputParser()
+    chain = rename_chain(chain,"answer")
+
+    return chain
 
 
 
 
 
+def make_illustration_chain(llm):
+
+    prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
+
+    input_description_images = {
+        "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
+        **pass_values(["question","audience","language","answer"]),
+    }
+
+    illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
+    return illustration_chain
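`make_rag_papers_chain` is a plain LCEL pipeline (input mapping → prompt → llm → `StrOutputParser`), so it can be invoked directly or streamed with `astream_log` the way `find_papers` does in app.py. A rough usage sketch; the `get_llm` import path and the OpenAI credentials are assumptions, not part of this diff:

```python
# Rough usage sketch for make_rag_papers_chain. Assumes the climateqa package is
# importable and that get_llm lives where app.py imports it from (not shown here).
from climateqa.engine.llm import get_llm   # assumed import path
from climateqa.engine.chains.answer_rag import make_rag_papers_chain

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
chain = make_rag_papers_chain(llm)

# As in find_papers, docs is a list of abstract strings taken from df_works["content"].
docs = [
    "Abstract 1: observed global mean sea level rise since 1900 ...",
    "Abstract 2: projected contribution of ice-sheet melt to sea level ...",
]

# Blocking call; app.py streams instead via chain.astream_log(...) and reads tokens
# at "/logs/StrOutputParser/streamed_output/-".
answer = chain.invoke({"question": "What drives sea level rise?", "docs": docs, "language": "English"})
print(answer)
```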
climateqa/knowledge/openalex.py
CHANGED
@@ -62,11 +62,10 @@ class OpenAlex():
 
         scores = reranker.rank(
             query,
-            df["content"].tolist()
-            top_k = len(df),
+            df["content"].tolist()
         )
-        scores.
-        scores = [x
+        scores = sorted(scores.results, key = lambda x : x.document.doc_id)
+        scores = [x.score for x in scores]
         df["rerank_score"] = scores
         return df
 
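The fix here is about ordering: `reranker.rank(...)` returns results sorted by relevance, while the DataFrame assignment needs one score per row in the original row order, so the results are re-sorted by `doc_id` before the scores are extracted. A toy illustration of that realignment, without the reranker itself:

```python
# Toy illustration of the rerank() fix: results come back ordered by relevance,
# so they must be re-sorted by doc_id before being written back as a column.
import pandas as pd

df = pd.DataFrame({"content": ["paper A", "paper B", "paper C"]})

# Stand-in for reranker.rank(...).results: (doc_id, score), ordered by relevance.
ranked = [(2, 0.91), (0, 0.55), (1, 0.12)]

ranked_by_row = sorted(ranked, key=lambda x: x[0])   # back to row order 0, 1, 2
df["rerank_score"] = [score for _, score in ranked_by_row]
print(df)
```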
front/utils.py
CHANGED
@@ -108,6 +108,28 @@ def make_html_source(source,i):
     return card
 
 
+def make_html_df(df,i):
+    title = df['title'][i]
+    content = df['abstract'][i]
+    url = df['doi'][i]
+    publication_date = df['publication_year'][i]
+
+    card = f"""
+    <div class="card" id="doc{i}">
+        <div class="card-content">
+            <h2>Doc {i+1} - {title}</h2>
+            <p>{content}</p>
+        </div>
+        <div class="card-footer">
+            <span>{publication_date}</span>
+            <a href="{url}" target="_blank" class="pdf-link">
+        </div>
+    </div>
+    """
+
+    return card
+
+
 def make_html_figure_sources(source,i,img_str):
     meta = source.metadata
     content = source.page_content.strip()
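`make_html_df` pulls its fields by label (`df['title'][i]`), so the sketch below uses a frame with a default RangeIndex. A quick way to exercise it on toy data; the DOI values are placeholders and the `front` package is assumed to be importable from the repo root:

```python
# Usage sketch for make_html_df on a toy DataFrame. Assumes it is run from the
# repo root so that front.utils is importable; the rows are invented examples.
import pandas as pd
from front.utils import make_html_df

df_works = pd.DataFrame({
    "title": ["Sea level rise, 1900-2018", "Ice-sheet contributions to sea level"],
    "abstract": ["We reconstruct global mean sea level ...", "We quantify ice-sheet mass loss ..."],
    "doi": ["https://doi.org/10.0000/placeholder-1", "https://doi.org/10.0000/placeholder-2"],
    "publication_year": [2021, 2023],
})

docs_html = "".join(make_html_df(df_works, i) for i in range(len(df_works)))
print(docs_html)
```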