domenicrosati committed
Commit 964c419 · 1 Parent(s): dd426a1

add proper matching
app.py
CHANGED
@@ -6,7 +6,7 @@ import nltk
 import string
 from streamlit.components.v1 import html
 from sentence_transformers.cross_encoder import CrossEncoder as CE
-import
+import re
 from typing import List, Tuple
 import torch
 
@@ -26,7 +26,7 @@ class CrossEncoder:
 def remove_html(x):
     soup = BeautifulSoup(x, 'html.parser')
     text = soup.get_text()
-    return text
+    return text.strip()
 
 
 # 4 searches: strict y/n, supported y/n
@@ -58,7 +58,7 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=Tru
     except:
         pass
 
-    contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']]
+    contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations'] if cite['lang'] == 'en'])) for doc in req.json()['hits']]
     docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
              for doc in req.json()['hits']]
 
@@ -85,10 +85,12 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=Tru
     )
 
 
-def find_source(text, docs):
+def find_source(text, docs, matched):
     for doc in docs:
         for snippet in doc[1]:
             if text in remove_html(snippet.get('snippet', '')):
+                if matched and remove_html(snippet.get('snippet', '')).strip() != matched.strip():
+                    continue
                 new_text = text
                 for sent in nltk.sent_tokenize(remove_html(snippet.get('snippet', ''))):
                     if text in sent:
@@ -98,10 +100,12 @@ def find_source(text, docs):
                         'text': new_text,
                         'from': snippet['source'],
                         'supporting': snippet['target'],
-                        'source_title': remove_html(doc[2]),
+                        'source_title': remove_html(doc[2] or ''),
                         'source_link': f"https://scite.ai/reports/{doc[0]}"
                     }
         if text in remove_html(doc[3]):
+            if matched and remove_html(doc[3]).strip() != matched.strip():
+                continue
             new_text = text
             for sent in nltk.sent_tokenize(remove_html(doc[3])):
                 if text in sent:
@@ -111,7 +115,7 @@ def find_source(text, docs):
                     'text': new_text,
                     'from': doc[0],
                     'supporting': doc[0],
-                    'source_title': "ABSTRACT of " + remove_html(doc[2]),
+                    'source_title': "ABSTRACT of " + remove_html(doc[2] or ''),
                     'source_link': f"https://scite.ai/reports/{doc[0]}"
                 }
     return None
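Review note: the new `matched` parameter threads through both branches of find_source so that an answer string is only attributed to a snippet (or abstract) when that text is the very chunk the QA model answered from, not just any text that happens to contain the answer. A minimal sketch of the acceptance rule, with made-up snippet data:

    matched = "Sleep loss impairs recall."
    snippets = [
        "Sleep loss impairs recall.",          # equals the matched chunk -> kept
        "Unrelated text that impairs recall.", # a different chunk -> skipped
    ]
    # Mirrors the guard above: with `matched` falsy, everything passes as before.
    kept = [s for s in snippets if not matched or s.strip() == matched.strip()]
    print(kept)  # ['Sleep loss impairs recall.']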
@@ -233,6 +237,22 @@ def group_results_by_context(results):
     return list(result_groups.values())
 
 
+def matched_context(start_i, end_i, contexts_string, seperator='---'):
+    # find seperators to identify start and end
+    doc_starts = [0]
+    for match in re.finditer(seperator, contexts_string):
+        doc_starts.append(match.end())
+
+    for i in range(len(doc_starts)):
+        if i == len(doc_starts) - 1:
+            if start_i >= doc_starts[i]:
+                return contexts_string[doc_starts[i]:len(contexts_string)].replace(seperator, '')
+
+        if start_i >= doc_starts[i] and end_i <= doc_starts[i+1]:
+            return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
+    return None
+
+
 def run_query(query):
     # if use_query_exp == 'yes':
     #     query_exp = paraphrase(f"question2question: {query}")
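Review note: a self-contained check of matched_context (the function body is copied from this patch; the sample chunks are made up). Given character offsets into the joined context string, it returns the single chunk containing that span:

    import re

    def matched_context(start_i, end_i, contexts_string, seperator='---'):
        # find seperators to identify start and end
        doc_starts = [0]
        for match in re.finditer(seperator, contexts_string):
            doc_starts.append(match.end())

        for i in range(len(doc_starts)):
            if i == len(doc_starts) - 1:
                if start_i >= doc_starts[i]:
                    return contexts_string[doc_starts[i]:len(contexts_string)].replace(seperator, '')

            if start_i >= doc_starts[i] and end_i <= doc_starts[i+1]:
                return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
        return None

    # Hypothetical chunks, joined the same way run_query now joins them.
    context = '\n---'.join(["alpha beta", "gamma delta", "epsilon zeta"])

    # Pretend the QA model answered with the span covering "delta".
    start_i = context.find("delta")
    end_i = start_i + len("delta")
    print(repr(matched_context(start_i, end_i, context)))  # 'gamma delta\n'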
@@ -278,19 +298,21 @@ def run_query(query):
         </div>
     </div>
     """, unsafe_allow_html=True)
+
     if use_reranking == 'yes':
         sentence_pairs = [[query, context] for context in contexts]
         scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
         hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
         sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
-        context = '\n'.join(sorted_contexts[:context_limit])
+        context = '\n---'.join(sorted_contexts[:context_limit])
     else:
-        context = '\n'.join(contexts[:context_limit])
+        context = '\n---'.join(contexts[:context_limit])
 
     results = []
     model_results = qa_model(question=query, context=context, top_k=10)
     for result in model_results:
-        support = find_source(result['answer'], orig_docs)
+        matched = matched_context(result['start'], result['end'], context)
+        support = find_source(result['answer'], orig_docs, matched)
         if not support:
             continue
         results.append({
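Review note: switching the join separator from '\n' to '\n---' is what makes the lookup above possible. The question-answering pipeline reports each answer as character offsets (result['start'], result['end']) into the concatenated context; matched_context resolves those offsets to exactly one chunk, and find_source then requires the attributed snippet to equal that chunk.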
@@ -316,10 +338,9 @@ def run_query(query):
     )
 
     for r in sorted_result:
-        answer = r["answer"]
         ctx = remove_html(r["context"])
         for answer in r['texts']:
-            ctx = ctx.replace(answer, f"<mark>{answer}</mark>")
+            ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
         # .replace( '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
         title = r.get("title", '')
         score = round(round(r["score"], 4) * 100, 2)