cpnepo committed on
Commit
c62cc25
1 Parent(s): 4537dbc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import streamlit as st
from sentence_transformers import SentenceTransformer, util
from transformers import (AutoModelForQuestionAnswering,
                          AutoTokenizer, pipeline)

import pandas as pd
import regex as re

# Extractive question-answering checkpoint (RoBERTa fine-tuned on SQuAD2).
model_name = "deepset/roberta-base-squad2"

# Load model & tokenizer once so the pipeline below can reuse them.
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Create the QA pipeline from the already-loaded objects.
# (Passing the checkpoint *name* here, as before, made transformers load
# the model and tokenizer a second time and left the objects above unused.)
pipe = pipeline('question-answering', model=model, tokenizer=tokenizer)

# Load the DFA Press Release dataset.
# NOTE(review): relative path assumes the app is launched next to `data/`
# — confirm the deployment layout.
df = pd.read_csv('../data/dfa_pr_v4_cleaned.csv')

# Each row's article text is one retrieval unit.
partitions = df['article'].values.tolist()

st.title('DFA Question Answering Model')

# Free-text question box.
query = st.text_area("Type in your question below:")

if st.button('Searching for the answer...'):
    if not query.strip():
        # Guard: an empty question would embed to a meaningless vector.
        st.write("Please type a question first.")
    else:
        # Embed the query and every article with a retrieval model trained
        # for dot-product similarity.
        model_embed_name = 'sentence-transformers/msmarco-distilbert-dot-v5'
        model_embed = SentenceTransformer(model_embed_name)
        doc_emb = model_embed.encode(partitions)
        query_emb = model_embed.encode(query)

        # Dot score between the query and all document embeddings.
        scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()

        # Pair each article with its score, then keep only the single
        # best-matching article as the QA context.
        doc_score_pairs = sorted(zip(partitions, scores),
                                 key=lambda x: x[1], reverse=True)[:1]

        # Join the retained article(s) to form the context string.
        context = "".join(x[0] for x in doc_score_pairs)

        # Run extractive QA over the retrieved context.
        QA_input = {'question': query, 'context': context}
        res = pipe(QA_input)

        # Only surface the extracted span when the model is confident;
        # otherwise fall back to an explicit "don't know" message.
        confidence = res.get('score', 0.0)
        if confidence > 0.8:
            st.write(res.get('answer'))
        else:
            st.write("I am not sure.")
        # Always show the supporting context for transparency.
        st.write(context)