File size: 5,776 Bytes
a560ed2 f5f9092 a560ed2 f5f9092 a560ed2 d9522e8 a560ed2 e2f1c79 b0154fd e2f1c79 a560ed2 e2f1c79 a560ed2 e2f1c79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import streamlit as st
from predict import run_prediction
from io import StringIO
import json
import spacy
from spacy import displacy
from transformers import AutoTokenizer, AutoModelForTokenClassification,RobertaTokenizer,pipeline
import torch
import nltk
from nltk.tokenize import sent_tokenize
from fin_readability_sustainability import BERTClass, do_predict
import pandas as pd
import en_core_web_sm
# spaCy English pipeline used for the entity visualisation of the answer.
nlp = en_core_web_sm.load()
# Sentence tokenizer data for nltk.sent_tokenize. Streamlit re-executes this
# script on every interaction, so keep the repeated download check quiet.
nltk.download('punkt', quiet=True)
# Streamlit requires set_page_config to be the first st.* call of the script.
st.set_page_config(layout="wide")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#SUSTAIN STARTS
# NOTE(review): the tokenizer and model below are reloaded on every Streamlit
# rerun; wrap them in a cached loader (applied as a decorator) if load time
# matters.
tokenizer_sus = RobertaTokenizer.from_pretrained('roberta-base')
# The "sustanability" spelling is passed through unchanged to BERTClass;
# presumably it must match what that class expects - do not "fix" it here.
model_sustain = BERTClass(2, "sustanability")
model_sustain.to(device)
model_sustain.load_state_dict(
    torch.load('sustainability_model.bin', map_location=device)['model_state_dict']
)
# Inference only: disable dropout etc. (harmless if do_predict also does this).
model_sustain.eval()
def get_sustainability(text, *, non_sustainable_threshold=4.384316,
                       sustainable_threshold=1.423736):
    """Label each sentence of *text* as sustainable, non-sustainable or neutral.

    The text is split into sentences with ``sent_tokenize``. The sentences are
    scored by ``do_predict`` using the module-level ``model_sustain`` /
    ``tokenizer_sus``.

    Args:
        text: Free text to analyse. ``None``, empty or whitespace-only text
            yields ``[]`` without running the model.
        non_sustainable_threshold: Scores greater than or equal to this value
            are labelled ``'non-sustainable'``.
        sustainable_threshold: Scores less than or equal to this value (and
            below ``non_sustainable_threshold``) are labelled ``'sustainable'``.

    Returns:
        A list of ``(sentence, label)`` tuples in sentence order. The label is
        ``'non-sustainable'``, ``'sustainable'`` or ``'-'`` (neither).
    """
    # Nothing to classify (e.g. the QA step found no answer): do not run the
    # model on an empty sentence table.
    if not text or not text.strip():
        return []
    df = pd.DataFrame({'sentence': sent_tokenize(text)})
    # NOTE(review): element [1] of do_predict's result is used as one score per
    # sentence. The values are compared against thresholds above 1, so they are
    # not probabilities - confirm their meaning in fin_readability_sustainability.
    scores = do_predict(model_sustain, tokenizer_sus, df)[1]
    highlight = []
    for sent, score in zip(df['sentence'].values, scores):
        if score >= non_sustainable_threshold:
            highlight.append((sent, 'non-sustainable'))
        elif score <= sustainable_threshold:
            highlight.append((sent, 'sustainable'))
        else:
            highlight.append((sent, '-'))
    return highlight
#SUSTAIN ENDS
##Summarization
def summarize_text(text):
    """Return an abstractive summary of *text*.

    Builds a Hugging Face ``summarization`` pipeline with the
    ``knkarthick/MEETING_SUMMARY`` model and returns the ``summary_text`` of
    its first result.

    ``None``, empty or whitespace-only input returns ``""`` without loading
    the model.
    """
    if not text or not text.strip():
        return ""
    # NOTE(review): the pipeline (and its model) is rebuilt on every call;
    # consider caching it across Streamlit reruns.
    summarizer = pipeline("summarization", model="knkarthick/MEETING_SUMMARY")
    return summarizer(text)[0]['summary_text']
##Forward Looking Statement
#def fls(text):
# fls_model = pipeline("text-classification", model="yiyanghkust/finbert-fls", tokenizer="yiyanghkust/finbert-fls")
# results = fls_model(split_in_sentences(text))
#return make_spans(text,results)
##Company Extraction
#ner=pipeline('ner',model='Jean-Baptiste/camembert-ner-with-dates',tokenizer='Jean-Baptiste/camembert-ner-with-dates', aggregation_strategy="simple")
#def fin_ner(text):
#replaced_spans = ner(text)
# return replaced_spans
def load_questions():
    """Read the full question texts from ``questions.txt``.

    Returns one string per line of the file, in file order. Line terminators
    are kept, exactly as ``readlines`` would return them.
    """
    with open('questions.txt') as handle:
        return list(handle)
def load_questions_short():
    """Read the short question labels from ``questionshort.txt``.

    Returns one string per line of the file, in file order. Line terminators
    are kept, exactly as ``readlines`` would return them.
    """
    with open('questionshort.txt') as handle:
        return list(handle)
# Parallel lists: questions_short[i] is the dropdown label shown for the full
# question questions[i] sent to the QA model.
questions = load_questions()
questions_short = load_questions_short()
### DEFINE SIDEBAR
st.sidebar.title("Interactive Contract Analysis")
st.sidebar.header('CONTRACT UPLOAD')
# Default contract shown until the user uploads their own. Read as UTF-8 to
# match how uploaded files are decoded below.
with open('NDA1.txt', encoding='utf-8') as f:
    contract_data = f.read()
# upload contract
user_upload = st.sidebar.file_uploader('Please upload your contract', type=['txt'],
                                       accept_multiple_files=False)
# process upload
if user_upload is not None:
    print(user_upload.name, user_upload.type)
    extension = user_upload.name.split('.')[-1].lower()
    if extension == 'txt':
        print('text file uploaded')
        try:
            contract_data = user_upload.getvalue().decode("utf-8")
        except UnicodeDecodeError:
            # Keep the default contract instead of crashing the whole app.
            st.warning('Uploaded file is not valid UTF-8 text, please try again')
    else:
        st.warning('Unknown uploaded file type, please try again')
results_drop = ['1', '2', '3']
number_results = st.sidebar.selectbox('Select number of results', results_drop)
# --- Main page: contract text and question picker ---------------------------
raw_answer = ""
questions_drop = questions_short
st.header("Legal Contract Review Demo")
paragraph = st.text_area(
    label="Contract",
    value=contract_data,
    height=300,
)
question_short = st.selectbox(
    'Choose one of the 41 queries from the CUAD dataset:',
    questions_drop,
)
# Short labels and full questions are parallel lists, so the position of the
# chosen label picks the full question text sent to the model.
idxq = questions_drop.index(question_short)
question = questions[idxq]
if st.button('Analyze'):
    if paragraph and question:
        print('getting predictions')
        with st.spinner(text='Analysis in progress...'):
            predictions = run_prediction([question], paragraph, 'marshmellow77/roberta-base-cuad',
                                         n_best_size=5)
        if predictions['0'] == "":
            answer = 'No answer found in document'
        else:
            # Ranked candidate answers; presumably written by run_prediction.
            with open("nbest.json") as jf:
                data = json.load(jf)
            # Slice rather than index so fewer candidates than requested
            # cannot raise IndexError.
            ranked = data['0'][:int(number_results)]
            answer = "".join(
                f"Answer {i}: {item['text']} -- \n"
                f"Probability: {round(item['probability'] * 100, 1)}%\n\n"
                for i, item in enumerate(ranked, start=1)
            )
            # Run the follow-up analyses on the best-ranked answer, not on
            # whichever answer happened to be displayed last.
            if ranked:
                raw_answer = ranked[0]['text']
        st.success(answer)
        if raw_answer:
            st.write(get_sustainability(raw_answer))
            st.write(summarize_text(raw_answer))
            doc = nlp(raw_answer)
            # displacy.render returns an HTML string. st.write only renders it
            # with unsafe_allow_html=True. Newlines are flattened so the
            # markdown renderer does not treat indented HTML lines as code.
            # NOTE(review): the text originates from the uploaded contract;
            # confirm the installed spaCy version HTML-escapes entity text.
            ent_html = displacy.render(doc, style="ent").replace("\n", " ")
            st.write(ent_html, unsafe_allow_html=True)
    else:
        st.write("Unable to call model, please select question and contract")
#if st.button('Check Sustainability'):
# if(raw_answer==""):
# st.write("Unable to call model, please select question and contract")
# else:
# st.write(get_sustainability(raw_answer))
#if st.button('Summarize'):
# if(raw_answer==""):
# st.write("Unable to call model, please select question and contract")
# else:
# st.write(summarize_text(raw_answer))
#if st.button('NER'):
# if(raw_answer==""):
# st.write("Unable to call model, please select question and contract")
# else:
# doc = nlp(raw_answer)
# st.write(displacy.render(doc, style="ent")) |