timeki commited on
Commit
c3b815e
1 Parent(s): 14a5a97

Add OpenAlex papers recommandation

Browse files
app.py CHANGED
@@ -8,6 +8,7 @@ from sentence_transformers import CrossEncoder
8
  oa = OpenAlex()
9
 
10
  import gradio as gr
 
11
  import pandas as pd
12
  import numpy as np
13
  import os
@@ -42,10 +43,12 @@ from climateqa.sample_questions import QUESTIONS
42
  from climateqa.constants import POSSIBLE_REPORTS
43
  from climateqa.utils import get_image_from_azure_blob_storage
44
  from climateqa.engine.keywords import make_keywords_chain
45
- # from climateqa.engine.chains.answer_rag import make_rag_papers_chain
 
 
46
  from climateqa.engine.graph import make_graph_agent,display_graph
47
 
48
- from front.utils import make_html_source, make_html_figure_sources,parse_output_llm_with_sources,serialize_docs,make_toolbox
49
 
50
  # Load environment variables in local mode
51
  try:
@@ -106,7 +109,7 @@ CITATION_TEXT = r"""@misc{climateqa,
106
  # Create vectorstore and retriever
107
  vectorstore = get_pinecone_vectorstore(embeddings_function)
108
  llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
109
- reranker = get_reranker("large")
110
  agent = make_graph_agent(llm,vectorstore,reranker)
111
 
112
 
@@ -326,13 +329,11 @@ def generate_keywords(query):
326
 
327
 
328
  papers_cols_widths = {
329
- "doc":50,
330
  "id":100,
331
  "title":300,
332
  "doi":100,
333
  "publication_year":100,
334
  "abstract":500,
335
- "rerank_score":100,
336
  "is_oa":50,
337
  }
338
 
@@ -340,6 +341,62 @@ papers_cols = list(papers_cols_widths.keys())
340
  papers_cols_widths = list(papers_cols_widths.values())
341
 
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  # --------------------------------------------------------------------
344
  # Gradio
345
  # --------------------------------------------------------------------
@@ -429,7 +486,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
429
  samples.append(group_examples)
430
 
431
 
432
- with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
433
  sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
434
  docs_textbox = gr.State("")
435
 
@@ -437,36 +494,29 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
437
 
438
 
439
  # with Modal(visible = False) as config_modal:
440
- with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
- gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
443
-
444
-
445
- dropdown_sources = gr.CheckboxGroup(
446
- ["IPCC", "IPBES","IPOS"],
447
- label="Select source",
448
- value=["IPCC"],
449
- interactive=True,
450
- )
451
-
452
- dropdown_reports = gr.Dropdown(
453
- POSSIBLE_REPORTS,
454
- label="Or select specific reports",
455
- multiselect=True,
456
- value=None,
457
- interactive=True,
458
- )
459
-
460
- dropdown_audience = gr.Dropdown(
461
- ["Children","General public","Experts"],
462
- label="Select audience",
463
- value="Experts",
464
- interactive=True,
465
- )
466
-
467
- output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
468
- output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
469
 
 
470
 
471
  with gr.Tab("Figures",elem_id = "tab-figures",id = 3):
472
  with Modal(visible=False, elem_id="modal_figure_galery") as modal:
@@ -490,6 +540,38 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
490
  # with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
491
  # gallery_component = gr.Gallery(object_fit='cover')
492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  # with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
494
 
495
  # with gr.Row():
@@ -546,6 +628,21 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
546
 
547
  def finish_chat():
548
  return (gr.update(interactive = True,value = ""))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
 
550
  (textbox
551
  .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
@@ -570,6 +667,11 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
570
 
571
  dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
572
 
 
 
 
 
 
573
 
574
  demo.queue()
575
 
 
8
  oa = OpenAlex()
9
 
10
  import gradio as gr
11
+ from gradio_modal import Modal
12
  import pandas as pd
13
  import numpy as np
14
  import os
 
43
  from climateqa.constants import POSSIBLE_REPORTS
44
  from climateqa.utils import get_image_from_azure_blob_storage
45
  from climateqa.engine.keywords import make_keywords_chain
46
+ from climateqa.engine.chains.answer_rag import make_rag_papers_chain
47
+ from climateqa.engine.graph import make_graph_agent,display_graph
48
+
49
  from climateqa.engine.graph import make_graph_agent,display_graph
50
 
51
+ from front.utils import make_html_source, make_html_figure_sources,parse_output_llm_with_sources,serialize_docs,make_toolbox,make_html_df
52
 
53
  # Load environment variables in local mode
54
  try:
 
109
  # Create vectorstore and retriever
110
  vectorstore = get_pinecone_vectorstore(embeddings_function)
111
  llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
112
+ reranker = get_reranker("nano")
113
  agent = make_graph_agent(llm,vectorstore,reranker)
114
 
115
 
 
329
 
330
 
331
  papers_cols_widths = {
 
332
  "id":100,
333
  "title":300,
334
  "doi":100,
335
  "publication_year":100,
336
  "abstract":500,
 
337
  "is_oa":50,
338
  }
339
 
 
341
  papers_cols_widths = list(papers_cols_widths.values())
342
 
343
 
344
+ async def find_papers(query,after):
345
+
346
+ summary = ""
347
+ keywords = generate_keywords(query)
348
+ df_works = oa.search(keywords,after = after)
349
+ df_works = df_works.dropna(subset=["abstract"])
350
+ df_works = oa.rerank(query,df_works,reranker)
351
+ df_works = df_works.sort_values("rerank_score",ascending=False)
352
+ docs_html = []
353
+ for i in range(10):
354
+ docs_html.append(make_html_df(df_works, i))
355
+ docs_html = "".join(docs_html)
356
+ print(docs_html)
357
+ G = oa.make_network(df_works)
358
+
359
+ height = "750px"
360
+ network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
361
+ network_html = network.generate_html()
362
+
363
+ network_html = network_html.replace("'", "\"")
364
+ css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
365
+ network_html = network_html + css_to_inject
366
+
367
+
368
+ network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
369
+ display-capture; encrypted-media;" sandbox="allow-modals allow-forms
370
+ allow-scripts allow-same-origin allow-popups
371
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
372
+ allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
373
+
374
+
375
+ docs = df_works["content"].head(10).tolist()
376
+
377
+ df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
378
+ df_works["doc"] = df_works["doc"] + 1
379
+ df_works = df_works[papers_cols]
380
+
381
+ yield docs_html, network_html, summary
382
+
383
+ chain = make_rag_papers_chain(llm)
384
+ result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
385
+ path_answer = "/logs/StrOutputParser/streamed_output/-"
386
+
387
+ async for op in result:
388
+
389
+ op = op.ops[0]
390
+
391
+ if op['path'] == path_answer: # reforulated question
392
+ new_token = op['value'] # str
393
+ summary += new_token
394
+ else:
395
+ continue
396
+ yield docs_html, network_html, summary
397
+
398
+
399
+
400
  # --------------------------------------------------------------------
401
  # Gradio
402
  # --------------------------------------------------------------------
 
486
  samples.append(group_examples)
487
 
488
 
489
+ with gr.Tab("Sources",elem_id = "tab-sources",id = 1):
490
  sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
491
  docs_textbox = gr.State("")
492
 
 
494
 
495
 
496
  # with Modal(visible = False) as config_modal:
497
+
498
+
499
+ with gr.Tab("Papers",elem_id = "tab-citations",id = 4):
500
+ btn_summary = gr.Button("Summary")
501
+ # Fenêtre simulée pour le Summary
502
+ with gr.Group(visible=False, elem_id="papers-summary-popup") as summary_popup:
503
+ papers_summary = gr.Markdown("### Summary Content", visible=True, elem_id="papers-summary")
504
+
505
+ btn_relevant_papers = gr.Button("Relevant papers")
506
+ # Fenêtre simulée pour les Relevant Papers
507
+ with gr.Group(visible=False, elem_id="papers-relevant-popup") as relevant_popup:
508
+ papers_html = gr.HTML(show_label=False, elem_id="sources-textbox")
509
+ docs_textbox = gr.State("")
510
+
511
+ btn_citations_network = gr.Button("Citations network")
512
+ # Fenêtre simulée pour le Citations Network
513
+ with Modal(visible=False) as modal:
514
+ citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
515
+ btn_citations_network.click(lambda: Modal(visible=True), None, modal)
516
 
517
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
 
519
+
520
 
521
  with gr.Tab("Figures",elem_id = "tab-figures",id = 3):
522
  with Modal(visible=False, elem_id="modal_figure_galery") as modal:
 
540
  # with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
541
  # gallery_component = gr.Gallery(object_fit='cover')
542
 
543
+ with gr.Tab("Settings",elem_id = "tab-config",id = 2):
544
+
545
+ gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
546
+
547
+
548
+ dropdown_sources = gr.CheckboxGroup(
549
+ ["IPCC", "IPBES","IPOS", "OpenAlex"],
550
+ label="Select source",
551
+ value=["IPCC"],
552
+ interactive=True,
553
+ )
554
+
555
+ dropdown_reports = gr.Dropdown(
556
+ POSSIBLE_REPORTS,
557
+ label="Or select specific reports",
558
+ multiselect=True,
559
+ value=None,
560
+ interactive=True,
561
+ )
562
+
563
+ dropdown_audience = gr.Dropdown(
564
+ ["Children","General public","Experts"],
565
+ label="Select audience",
566
+ value="Experts",
567
+ interactive=True,
568
+ )
569
+
570
+ after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
571
+
572
+ output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
573
+ output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
574
+
575
  # with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
576
 
577
  # with gr.Row():
 
628
 
629
  def finish_chat():
630
  return (gr.update(interactive = True,value = ""))
631
+
632
+ # Initialize visibility states
633
+ summary_visible = False
634
+ relevant_visible = False
635
+
636
+ # Functions to toggle visibility
637
+ def toggle_summary_visibility():
638
+ global summary_visible
639
+ summary_visible = not summary_visible
640
+ return gr.update(visible=summary_visible)
641
+
642
+ def toggle_relevant_visibility():
643
+ global relevant_visible
644
+ relevant_visible = not relevant_visible
645
+ return gr.update(visible=relevant_visible)
646
 
647
  (textbox
648
  .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
 
667
 
668
  dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
669
 
670
+ textbox.submit(find_papers,[textbox,after], [papers_html,citations_network,papers_summary])
671
+ examples_hidden.change(find_papers,[examples_hidden,after], [papers_html,citations_network,papers_summary])
672
+
673
+ btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
674
+ btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
675
 
676
  demo.queue()
677
 
climateqa/engine/chains/answer_rag.py CHANGED
@@ -7,6 +7,8 @@ from langchain_core.prompts.base import format_document
7
 
8
  from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
9
  from climateqa.engine.chains.prompts import papers_prompt_template
 
 
10
 
11
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
12
 
@@ -68,32 +70,32 @@ def make_rag_node(llm,with_docs = True):
68
 
69
 
70
 
71
- # def make_rag_papers_chain(llm):
72
 
73
- # prompt = ChatPromptTemplate.from_template(papers_prompt_template)
74
- # input_documents = {
75
- # "context":lambda x : _combine_documents(x["docs"]),
76
- # **pass_values(["question","language"])
77
- # }
78
 
79
- # chain = input_documents | prompt | llm | StrOutputParser()
80
- # chain = rename_chain(chain,"answer")
81
 
82
- # return chain
83
 
84
 
85
 
86
 
87
 
88
 
89
- # def make_illustration_chain(llm):
90
 
91
- # prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
92
 
93
- # input_description_images = {
94
- # "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
95
- # **pass_values(["question","audience","language","answer"]),
96
- # }
97
 
98
- # illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
99
- # return illustration_chain
 
7
 
8
  from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
9
  from climateqa.engine.chains.prompts import papers_prompt_template
10
+ from ..utils import rename_chain, pass_values
11
+
12
 
13
  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
14
 
 
70
 
71
 
72
 
73
+ def make_rag_papers_chain(llm):
74
 
75
+ prompt = ChatPromptTemplate.from_template(papers_prompt_template)
76
+ input_documents = {
77
+ "context":lambda x : _combine_documents(x["docs"]),
78
+ **pass_values(["question","language"])
79
+ }
80
 
81
+ chain = input_documents | prompt | llm | StrOutputParser()
82
+ chain = rename_chain(chain,"answer")
83
 
84
+ return chain
85
 
86
 
87
 
88
 
89
 
90
 
91
+ def make_illustration_chain(llm):
92
 
93
+ prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
94
 
95
+ input_description_images = {
96
+ "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
97
+ **pass_values(["question","audience","language","answer"]),
98
+ }
99
 
100
+ illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
101
+ return illustration_chain
climateqa/knowledge/openalex.py CHANGED
@@ -62,11 +62,10 @@ class OpenAlex():
62
 
63
  scores = reranker.rank(
64
  query,
65
- df["content"].tolist(),
66
- top_k = len(df),
67
  )
68
- scores.sort(key = lambda x : x["corpus_id"])
69
- scores = [x["score"] for x in scores]
70
  df["rerank_score"] = scores
71
  return df
72
 
 
62
 
63
  scores = reranker.rank(
64
  query,
65
+ df["content"].tolist()
 
66
  )
67
+ scores = sorted(scores.results, key = lambda x : x.document.doc_id)
68
+ scores = [x.score for x in scores]
69
  df["rerank_score"] = scores
70
  return df
71
 
front/utils.py CHANGED
@@ -108,6 +108,28 @@ def make_html_source(source,i):
108
  return card
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def make_html_figure_sources(source,i,img_str):
112
  meta = source.metadata
113
  content = source.page_content.strip()
 
108
  return card
109
 
110
 
111
+ def make_html_df(df,i):
112
+ title = df['title'][i]
113
+ content = df['abstract'][i]
114
+ url = df['doi'][i]
115
+ publication_date = df['publication_year'][i]
116
+
117
+ card = f"""
118
+ <div class="card" id="doc{i}">
119
+ <div class="card-content">
120
+ <h2>Doc {i+1} - {title}</h2>
121
+ <p>{content}</p>
122
+ </div>
123
+ <div class="card-footer">
124
+ <span>{publication_date}</span>
125
+ <a href="{url}" target="_blank" class="pdf-link">
126
+ </div>
127
+ </div>
128
+ """
129
+
130
+ return card
131
+
132
+
133
  def make_html_figure_sources(source,i,img_str):
134
  meta = source.metadata
135
  content = source.page_content.strip()