domenicrosati committed
Commit
964c419
•
1 Parent(s): dd426a1

add proper matching

Files changed (1)
  1. app.py +32 -11
app.py CHANGED
@@ -6,7 +6,7 @@ import nltk
 import string
 from streamlit.components.v1 import html
 from sentence_transformers.cross_encoder import CrossEncoder as CE
-import numpy as np
+import re
 from typing import List, Tuple
 import torch
 
@@ -26,7 +26,7 @@ class CrossEncoder:
 def remove_html(x):
     soup = BeautifulSoup(x, 'html.parser')
     text = soup.get_text()
-    return text
+    return text.strip()
 
 
 # 4 searches: strict y/n, supported y/n
@@ -58,7 +58,7 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=Tru
     except:
         pass
 
-    contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']]
+    contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations'] if cite['lang'] == 'en'])) for doc in req.json()['hits']]
     docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
              for doc in req.json()['hits']]
 
@@ -85,10 +85,12 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=Tru
     )
 
 
-def find_source(text, docs):
+def find_source(text, docs, matched):
     for doc in docs:
         for snippet in doc[1]:
             if text in remove_html(snippet.get('snippet', '')):
+                if matched and remove_html(snippet.get('snippet', '')).strip() != matched.strip():
+                    continue
                 new_text = text
                 for sent in nltk.sent_tokenize(remove_html(snippet.get('snippet', ''))):
                     if text in sent:
@@ -98,10 +100,12 @@ def find_source(text, docs):
                     'text': new_text,
                     'from': snippet['source'],
                     'supporting': snippet['target'],
-                    'source_title': remove_html(doc[2]),
+                    'source_title': remove_html(doc[2] or ''),
                     'source_link': f"https://scite.ai/reports/{doc[0]}"
                 }
         if text in remove_html(doc[3]):
+            if matched and remove_html(doc[3]).strip() != matched.strip():
+                continue
             new_text = text
             for sent in nltk.sent_tokenize(remove_html(doc[3])):
                 if text in sent:
@@ -111,7 +115,7 @@ def find_source(text, docs):
                 'text': new_text,
                 'from': doc[0],
                 'supporting': doc[0],
-                'source_title': "ABSTRACT of " + remove_html(doc[2]),
+                'source_title': "ABSTRACT of " + remove_html(doc[2] or ''),
                 'source_link': f"https://scite.ai/reports/{doc[0]}"
             }
     return None
@@ -233,6 +237,22 @@ def group_results_by_context(results):
     return list(result_groups.values())
 
 
+def matched_context(start_i, end_i, contexts_string, seperator='---'):
+    # find seperators to identify start and end
+    doc_starts = [0]
+    for match in re.finditer(seperator, contexts_string):
+        doc_starts.append(match.end())
+
+    for i in range(len(doc_starts)):
+        if i == len(doc_starts) - 1:
+            if start_i >= doc_starts[i]:
+                return contexts_string[doc_starts[i]:len(contexts_string)].replace(seperator, '')
+
+        if start_i >= doc_starts[i] and end_i <= doc_starts[i+1]:
+            return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
+    return None
+
+
 def run_query(query):
     # if use_query_exp == 'yes':
     #     query_exp = paraphrase(f"question2question: {query}")
@@ -278,19 +298,21 @@ def run_query(query):
         </div>
     </div>
     """, unsafe_allow_html=True)
+
     if use_reranking == 'yes':
         sentence_pairs = [[query, context] for context in contexts]
         scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
         hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
         sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
-        context = '\n'.join(sorted_contexts[:context_limit])
+        context = '\n---'.join(sorted_contexts[:context_limit])
     else:
-        context = '\n'.join(contexts[:context_limit])
+        context = '\n---'.join(contexts[:context_limit])
 
     results = []
     model_results = qa_model(question=query, context=context, top_k=10)
     for result in model_results:
-        support = find_source(result['answer'], orig_docs)
+        matched = matched_context(result['start'], result['end'], context)
+        support = find_source(result['answer'], orig_docs, matched)
         if not support:
             continue
         results.append({
@@ -316,10 +338,9 @@ def run_query(query):
     )
 
     for r in sorted_result:
-        answer = r["answer"]
        ctx = remove_html(r["context"])
         for answer in r['texts']:
-            ctx = ctx.replace(answer, f"<mark>{answer}</mark>")
+            ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
             # .replace( '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
         title = r.get("title", '')
         score = round(round(r["score"], 4) * 100, 2)
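
The core of the change: run_query now joins the reranked contexts with a '---' separator, and the QA pipeline's character offsets (result['start'] / result['end']) are mapped back through matched_context to the single context the answer was extracted from. A minimal standalone sketch of that helper, using made-up contexts (the real app builds them from scite.ai citation snippets):

import re

def matched_context(start_i, end_i, contexts_string, seperator='---'):
    # find seperators to identify start and end
    doc_starts = [0]
    for match in re.finditer(seperator, contexts_string):
        doc_starts.append(match.end())

    for i in range(len(doc_starts)):
        if i == len(doc_starts) - 1:
            if start_i >= doc_starts[i]:
                return contexts_string[doc_starts[i]:len(contexts_string)].replace(seperator, '')

        if start_i >= doc_starts[i] and end_i <= doc_starts[i+1]:
            return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
    return None

# Toy stand-ins for the app's reranked contexts and the QA model's offsets.
contexts = ["Aspirin reduces fever.", "Ibuprofen reduces inflammation."]
joined = '\n---'.join(contexts)     # same separator the commit uses
start = joined.index("Ibuprofen")   # plays the role of result['start']
end = start + len("Ibuprofen")      # plays the role of result['end']
print(matched_context(start, end, joined))
# -> Ibuprofen reduces inflammation.

Note that seperator is passed to re.finditer as a regex pattern, so a context containing a literal '---' would split in unexpected places; the exact-equality check in find_source relies on the same separator discipline.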
 
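With matched in hand, find_source now only accepts a citation snippet (or abstract) whose cleaned text is exactly the context the model answered from; this is also why remove_html now strips whitespace, so both sides of the comparison are normalized the same way. A hypothetical keep_snippet helper restating that guard (the real check is inline in find_source; the names here are illustrative):

from bs4 import BeautifulSoup

def remove_html(x):
    # same cleaning app.py applies before any comparison
    return BeautifulSoup(x, 'html.parser').get_text().strip()

def keep_snippet(snippet_html, answer, matched):
    clean = remove_html(snippet_html)
    if answer not in clean:
        return False
    # New guard from this commit: require exact (stripped) equality with
    # the context the QA model actually read, so an answer string that
    # happens to appear in several snippets maps to the right source.
    return not matched or clean == matched.strip()

print(keep_snippet("<i>Ibuprofen</i> reduces inflammation.",
                   "Ibuprofen", "Ibuprofen reduces inflammation."))  # True
print(keep_snippet("<i>Ibuprofen</i> reduces inflammation.",
                   "Ibuprofen", "Aspirin reduces fever."))           # False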