File size: 4,012 Bytes
111325b
508732a
 
 
 
 
 
 
 
a1c7b5a
508732a
b9ebf82
508732a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3a87c9
b9ebf82
508732a
 
 
 
 
a1c7b5a
508732a
f916292
508732a
b6704c3
 
 
508732a
b6704c3
508732a
 
 
2ea27b8
508732a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd2f707
508732a
 
 
 
 
a3a87c9
508732a
ee4baf9
508732a
ee4baf9
508732a
 
 
 
 
 
 
 
 
 
f916292
508732a
2ea27b8
508732a
 
2ea27b8
508732a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
from markdown import markdown
from annotated_text import annotation
import logging

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import TfidfRetriever
from haystack.pipelines import ExtractiveQAPipeline
from haystack.nodes import FARMReader
import time
import joblib
from random import choice

# NOTE(review): hash_funcs maps SwigPyObject to a constant hash — presumably because
# st.cache cannot hash SWIG-wrapped objects created inside the pipeline; confirm.
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True)
def create_pipeline():
    """Build the extractive question-answering pipeline used by the app.

    Loads the document collection from ``docs.joblib``, indexes it in an
    ``InMemoryDocumentStore``, and pairs a TF-IDF retriever over that store
    with a FARM reader using the ``ixa-ehu/SciBERT-SQuAD-QuAC`` model.

    Returns:
        ExtractiveQAPipeline: the reader + retriever pipeline. The result is
        memoised by ``st.cache``, so later script reruns reuse it.

    NOTE(review): ``joblib.load`` unpickles the file, so only point it at
    trusted data. ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``; check the pinned version before migrating.
    """
    corpus = joblib.load('docs.joblib')

    store = InMemoryDocumentStore()
    store.write_documents(corpus)

    # The retriever is built after the documents are written so it sees them.
    tfidf_retriever = TfidfRetriever(store)
    scibert_reader = FARMReader(model_name_or_path="ixa-ehu/SciBERT-SQuAD-QuAC")

    return ExtractiveQAPipeline(scibert_reader, tfidf_retriever)
    
# Module-level pipeline used by ask_question. create_pipeline is wrapped in
# @st.cache, so script reruns get the memoised pipeline instead of rebuilding it.
pipeline = create_pipeline()

def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in Streamlit session state unless the key already exists.

    Existing entries are never overwritten, so this is safe to call on every
    script rerun to seed defaults.
    """
    if key in st.session_state:
        return  # keep the value set by an earlier run
    st.session_state[key] = value

# Pool of example questions used to seed the query box and by the
# "Random Question" button (see change_query).
# NOTE(review): this line is not cached, so queries.joblib is re-read whenever
# Streamlit reruns the script; also choice() raises IndexError if the loaded
# sequence is empty. Consider a cached loader if the file grows.
queries = joblib.load('queries.joblib')
# Seed defaults only on the first run; later runs keep the stored values.
# (choice() is still evaluated each time, its result just goes unused.)
set_state_if_absent("question", choice(queries))
set_state_if_absent("results", None)

def reset_results(*args):
    """Widget callback that discards any previously stored results.

    Any positional arguments are ignored.
    """
    st.session_state["results"] = None

# Static page header: app title plus links describing the Haystack components
# (ExtractiveQAPipeline, InMemoryDocumentStore) that create_pipeline wires up.
st.markdown('''# Welcome to **SRM RP explorer**!
This QA demo uses a [Haystack Extractive QA Pipeline](https://haystack.deepset.ai/components/ready-made-pipelines#extractiveqapipeline) with 
an [InMemoryDocumentStore](https://haystack.deepset.ai/components/document-store) which contains abstracts of 17k+ research papers associated with SRM university.''')

def change_query(*args):
    """Button callback that replaces the stored question with a random entry from ``queries``.

    Any positional arguments are ignored. The draw may repeat the current
    question.
    """
    st.session_state["question"] = choice(queries)

# Query box pre-filled from session state (max 100 characters). Editing it
# clears stored results via reset_results; the button stores a new random
# question via change_query.
# NOTE(review): the text_input has no `key`, so Streamlit identifies it by its
# arguments — a changed `value` after "Random Question" presumably produces a
# fresh widget showing the new question. Confirm this is the intended mechanism.
query = st.text_input('Enter a query to get started:', value=st.session_state.question, max_chars=100, on_change=reset_results)
st.button('Random Question', on_click=change_query)

def _answer_to_result(answer):
    """Flatten one answer dict (from Haystack ``Answer.to_dict()``) into a display record.

    Every record has the same keys. For an empty answer, only ``score`` is
    filled; the other fields are None.
    """
    score = round(answer["score"] * 100, 2)
    if not answer["answer"]:
        return {
            "title": None,
            "link": None,
            "context": None,
            "answer": None,
            "score": score,
            "offset_start_in_doc": None,
        }
    # offsets_in_document may be None or empty; don't let that sink the whole query.
    offsets = answer["offsets_in_document"] or []
    return {
        "title": answer["meta"]["name"],
        "link": answer["meta"]["link"],
        "context": "..." + answer["context"] + "...",
        "answer": answer["answer"],
        "score": score,
        "offset_start_in_doc": offsets[0]["start"] if offsets else None,
    }


def ask_question(query, retriever_top_k=6, reader_top_k=3):
    """Run ``query`` through the module-level QA pipeline and return display records.

    Writes the elapsed time to the page with ``st.write``.

    Args:
        query: The user's question.
        retriever_top_k: Number of documents the retriever passes to the reader.
        reader_top_k: Number of answers the reader returns.

    Returns:
        list[dict]: One record per predicted answer, in pipeline order, with keys
        ``title``, ``link``, ``context``, ``answer``, ``score`` (percent, rounded
        to 2 decimals) and ``offset_start_in_doc``.
    """
    # perf_counter is monotonic; wall-clock time.time() can jump mid-measurement.
    started = time.perf_counter()
    prediction = pipeline.run(
        query=query,
        params={"Retriever": {"top_k": retriever_top_k}, "Reader": {"top_k": reader_top_k}},
    )
    st.write('Time taken: %s s' % round(time.perf_counter() - started, 2))
    return [_answer_to_result(answer.to_dict()) for answer in prediction["answers"]]

# Answer the current query. Streamlit re-executes the script on each
# interaction, so every rerun with a non-empty query box queries the pipeline again.
if query:
    with st.spinner("🧠    Performing semantic search on abstracts..."):
        logging.info("Asked %s", query)
        try:
            st.session_state.results = ask_question(query)
        except Exception:
            # Top-level UI boundary. Log the traceback, and drop results left over
            # from an earlier query so they are not shown under this one. Then
            # tell the user instead of failing silently.
            logging.exception("Failed to answer query: %s", query)
            st.session_state.results = None
            st.error("Something went wrong while answering your question. Please try again.")

# Render the stored results: a titled link per answer, the context with the
# answer highlighted, and the relevance score. Empty answers show a hint instead.
if st.session_state.results:
    st.write('## Top Results')
    for result in st.session_state.results:
        if not result["answer"]:
            st.info(
                "🤔    Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
            )
            continue

        answer, context = result["answer"], result["context"]
        highlight = str(annotation(body=answer, label="RELEVANT", background="#67a17a", color='#ffffff'))
        # NOTE(review): find() picks the first occurrence, which may not be the
        # reader's actual span if the answer text repeats in the context.
        start_idx = context.find(answer)
        if start_idx == -1:
            # The answer text is absent from the context window. Splicing at -1
            # would duplicate and garble the text, so show the highlighted answer
            # ahead of the unmodified context instead.
            body = highlight + " " + context
        else:
            end_idx = start_idx + len(answer)
            body = context[:start_idx] + highlight + context[end_idx:]

        st.markdown(f"### [{result['title']}]({result['link']})")
        # NOTE(review): document text is rendered as raw HTML. This is acceptable
        # for the bundled corpus; escape it first if documents can be untrusted.
        st.write(markdown(body), unsafe_allow_html=True)
        st.markdown(f"**Relevance:** {result['score']}")