timeki committed on
Commit
363fe2e
1 Parent(s): 7ec5d9e

move papers-related code into a separate file

app.py CHANGED
@@ -1,11 +1,9 @@
 from climateqa.engine.embeddings import get_embeddings_function
 embeddings_function = get_embeddings_function()
 
-from climateqa.knowledge.openalex import OpenAlex
 from sentence_transformers import CrossEncoder
 
 # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
-oa = OpenAlex()
 
 import gradio as gr
 from gradio_modal import Modal
@@ -44,10 +42,9 @@ from climateqa.engine.chains.prompts import audience_prompts
 from climateqa.sample_questions import QUESTIONS
 from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
 from climateqa.utils import get_image_from_azure_blob_storage
-from climateqa.engine.keywords import make_keywords_chain
-from climateqa.engine.chains.answer_rag import make_rag_papers_chain
 from climateqa.engine.graph import make_graph_agent
 from climateqa.engine.embeddings import get_embeddings_function
+from climateqa.engine.chains.retrieve_papers import find_papers
 
 from front.utils import serialize_docs,process_figures,make_html_df
 
@@ -249,84 +246,9 @@ def log_on_azure(file, logs, share_client):
     file_client.upload_file(logs)
 
 
-def generate_keywords(query):
-    chain = make_keywords_chain(llm)
-    keywords = chain.invoke(query)
-    keywords = " AND ".join(keywords["keywords"])
-    return keywords
 
 
 
-papers_cols_widths = {
-    "id":100,
-    "title":300,
-    "doi":100,
-    "publication_year":100,
-    "abstract":500,
-    "is_oa":50,
-}
-
-papers_cols = list(papers_cols_widths.keys())
-papers_cols_widths = list(papers_cols_widths.values())
-
-
-async def find_papers(query,after, relevant_content_sources):
-    if "OpenAlex" in relevant_content_sources:
-        summary = ""
-        keywords = generate_keywords(query)
-        df_works = oa.search(keywords,after = after)
-        df_works = df_works.dropna(subset=["abstract"])
-        df_works = oa.rerank(query,df_works,reranker)
-        df_works = df_works.sort_values("rerank_score",ascending=False)
-        docs_html = []
-        for i in range(10):
-            docs_html.append(make_html_df(df_works, i))
-        docs_html = "".join(docs_html)
-        print(docs_html)
-        G = oa.make_network(df_works)
-
-        height = "750px"
-        network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
-        network_html = network.generate_html()
-
-        network_html = network_html.replace("'", "\"")
-        css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
-        network_html = network_html + css_to_inject
-
-
-        network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
-        display-capture; encrypted-media;" sandbox="allow-modals allow-forms
-        allow-scripts allow-same-origin allow-popups
-        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
-        allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
-
-
-        docs = df_works["content"].head(10).tolist()
-
-        df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
-        df_works["doc"] = df_works["doc"] + 1
-        df_works = df_works[papers_cols]
-
-        yield docs_html, network_html, summary
-
-        chain = make_rag_papers_chain(llm)
-        result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
-        path_answer = "/logs/StrOutputParser/streamed_output/-"
-
-        async for op in result:
-
-            op = op.ops[0]
-
-            if op['path'] == path_answer: # reformulated question
-                new_token = op['value'] # str
-                summary += new_token
-            else:
-                continue
-            yield docs_html, network_html, summary
-    else:
-        yield "","", ""
-
-
 # --------------------------------------------------------------------
 # Gradio
 # --------------------------------------------------------------------
@@ -430,7 +352,10 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
     with gr.Tab("Configuration", id = 10, ) as tab_config:
         gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
 
+
+
         with gr.Row():
+
             dropdown_sources = gr.CheckboxGroup(
                 ["IPCC", "IPBES","IPOS"],
                 label="Select source",
@@ -443,7 +368,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
                 value=["IPCC figures"],
                 interactive=True,
             )
-            
+
            dropdown_reports = gr.Dropdown(
                 POSSIBLE_REPORTS,
                 label="Or select specific reports",
@@ -452,6 +377,9 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
                 interactive=True,
             )
 
+            search_only = gr.Checkbox(label="Search only without chatting", value=False, interactive=True, elem_id="checkbox-chat")
+
+
             dropdown_audience = gr.Dropdown(
                 ["Children","General public","Experts"],
                 label="Select audience",
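
The app.py hunks above show only the new import and the search_only checkbox; the event wiring for find_papers sits outside the diff context. As a rough sketch of how an async generator like find_papers can drive Gradio outputs (all component names below are hypothetical, not taken from this commit):

# Hypothetical wiring sketch only -- textbox, after_slider, dropdown_external_sources,
# papers_html, papers_network and papers_summary are illustrative names.
from climateqa.engine.chains.retrieve_papers import find_papers

textbox.submit(
    find_papers,  # async generator: Gradio streams each yielded tuple into the outputs
    inputs=[textbox, after_slider, dropdown_external_sources],
    outputs=[papers_html, papers_network, papers_summary],
)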
climateqa/engine/chains/retrieve_papers.py ADDED
@@ -0,0 +1,95 @@
+from climateqa.engine.keywords import make_keywords_chain
+from climateqa.engine.llm import get_llm
+from climateqa.knowledge.openalex import OpenAlex
+from climateqa.engine.chains.answer_rag import make_rag_papers_chain
+from front.utils import make_html_df
+from climateqa.engine.reranker import get_reranker
+
+oa = OpenAlex()
+
+llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
+reranker = get_reranker("nano")
+
+
+papers_cols_widths = {
+    "id":100,
+    "title":300,
+    "doi":100,
+    "publication_year":100,
+    "abstract":500,
+    "is_oa":50,
+}
+
+papers_cols = list(papers_cols_widths.keys())
+papers_cols_widths = list(papers_cols_widths.values())
+
+
+
+def generate_keywords(query):
+    chain = make_keywords_chain(llm)
+    keywords = chain.invoke(query)
+    keywords = " AND ".join(keywords["keywords"])
+    return keywords
+
+
+async def find_papers(query,after, relevant_content_sources, reranker= reranker):
+    if "OpenAlex" in relevant_content_sources:
+        summary = ""
+        keywords = generate_keywords(query)
+        df_works = oa.search(keywords,after = after)
+
+        print(f"Found {len(df_works)} papers")
+
+        if not df_works.empty:
+            df_works = df_works.dropna(subset=["abstract"])
+            df_works = df_works[df_works["abstract"] != ""].reset_index(drop = True)
+            df_works = oa.rerank(query,df_works,reranker)
+            df_works = df_works.sort_values("rerank_score",ascending=False)
+            docs_html = []
+            for i in range(10):
+                docs_html.append(make_html_df(df_works, i))
+            docs_html = "".join(docs_html)
+            G = oa.make_network(df_works)
+
+            height = "750px"
+            network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
+            network_html = network.generate_html()
+
+            network_html = network_html.replace("'", "\"")
+            css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
+            network_html = network_html + css_to_inject
+
+
+            network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
+            display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+            allow-scripts allow-same-origin allow-popups
+            allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+            allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
+
+
+            docs = df_works["content"].head(10).tolist()
+
+            df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
+            df_works["doc"] = df_works["doc"] + 1
+            df_works = df_works[papers_cols]
+
+            yield docs_html, network_html, summary
+
+            chain = make_rag_papers_chain(llm)
+            result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
+            path_answer = "/logs/StrOutputParser/streamed_output/-"
+
+            async for op in result:
+
+                op = op.ops[0]
+
+                if op['path'] == path_answer: # reformulated question
+                    new_token = op['value'] # str
+                    summary += new_token
+                else:
+                    continue
+                yield docs_html, network_html, summary
+        else:
+            print("No papers found")
+    else:
+        yield "","", ""
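
Since find_papers is an async generator, each yield delivers a complete (docs_html, network_html, summary) triple: once after the search and citation network are built, then repeatedly as the RAG summary streams token by token. A minimal consumption sketch outside Gradio (the query text and the after value are made up for illustration):

import asyncio

async def demo():
    # Each iteration receives the full triple; only summary changes while streaming.
    async for docs_html, network_html, summary in find_papers(
        "impact of deep sea mining",            # user query
        after=2015,                             # assumed publication-year cutoff
        relevant_content_sources=["OpenAlex"],
    ):
        print(f"summary so far: {len(summary)} chars")

asyncio.run(demo())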
climateqa/engine/keywords.py CHANGED
@@ -11,10 +11,12 @@ class KeywordsOutput(BaseModel):
 
     keywords: list = Field(
         description="""
-        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
+        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers. Answer only with English keywords.
+        Do not use special characters or accents.
 
         Example:
         - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
+        - "Quel est l'impact de l'exploitation minière en haute mer ?" -> ["deep sea mining"]
         - "How will El Nino be impacted by climate change" -> ["el nino"]
         - "Is climate change a hoax" -> ["climate change","hoax"]
         """
climateqa/knowledge/openalex.py CHANGED
@@ -41,6 +41,10 @@ class OpenAlex():
             break
 
         df_works = pd.DataFrame(page)
+
+        if df_works.empty:
+            return df_works
+
         df_works = df_works.dropna(subset = ["title"])
         df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
         df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
  df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")