# testitest / app.py
# Committed by crocidoc: "initial commit" (cc83a1d), 3.95 kB
import streamlit as st
import text_transformation_tools as ttt
from transformers import pipeline
import plotly.express as px
def read_pdf(file):
    """Extract the text of an uploaded PDF.

    Args:
        file: The uploaded PDF file object to hand to ``ttt.pdf_to_text``.

    Returns:
        The result of ``ttt.pdf_to_text(file)``. The app treats it as a list of
        paragraph strings (presumably; confirm against text_transformation_tools).
    """
    return ttt.pdf_to_text(file)
def analyze_text(paragraphs, topics, model, mode, min_chars, prob):
    """Classify the passages of a document against a set of topics.

    A Hugging Face zero-shot-classification pipeline is built for ``model`` and
    run on every passage: whole paragraphs, or the pieces obtained by splitting
    each paragraph on '.'. Progress is written into a single Streamlit
    placeholder that is overwritten on each step.

    Args:
        paragraphs: Iterable of paragraph strings extracted from the document.
        topics: Candidate labels passed to the classifier.
        model: Hugging Face model id of the zero-shot model to load.
        mode: ``'paragraphs'`` or ``'sentences'``.
        min_chars: Passages shorter than this, after whitespace clean-up, are
            skipped.
        prob: Minimum per-topic score a passage needs to be recorded for that
            topic.

    Returns:
        dict mapping each topic to the list of original (uncleaned) passages
        whose score for that topic is at least ``prob``. If ``topics`` is
        empty, this is an empty dict and no model is loaded.

    Raises:
        ValueError: If ``mode`` is not ``'paragraphs'`` or ``'sentences'``.
    """
    # Build the passage list (and validate ``mode``) before paying for the model load.
    if mode == 'paragraphs':
        sequences = list(paragraphs)
    elif mode == 'sentences':
        sequences = [sentence for paragraph in paragraphs for sentence in paragraph.split('.')]
    else:
        raise ValueError("mode must be 'paragraphs' or 'sentences', got {!r}".format(mode))

    relevant_parts = {topic: [] for topic in topics}
    if not relevant_parts:
        # With no candidate labels nothing can be relevant, and the pipeline
        # would reject the call anyway.
        return relevant_parts

    with st.spinner('Loading model'):
        classifier = pipeline('zero-shot-classification', model=model)

    with st.spinner('Analyzing text...'):
        counter_tot = len(sequences)
        counter_rel = 0
        # Writing inside st.empty() replaces the previous progress message.
        with st.empty():
            for counter, sequence_to_classify in enumerate(sequences, start=1):
                # Line breaks from PDF extraction separate words: collapse every
                # whitespace run (newlines included) into a single space.
                cleansed_sequence = ' '.join(sequence_to_classify.split())
                # Never send empty fragments (e.g. after a trailing '.') to the
                # classifier, even when min_chars is 0.
                if cleansed_sequence and len(cleansed_sequence) >= min_chars:
                    classified = classifier(cleansed_sequence, topics, multi_label=True)
                    is_relevant = False
                    for label, score in zip(classified['labels'], classified['scores']):
                        if score >= prob:
                            relevant_parts[label].append(sequence_to_classify)
                            is_relevant = True
                    # Count each passage once, even if it matches several topics.
                    if is_relevant:
                        counter_rel += 1
                st.write('Analyzed {} of {} {}. Found {} relevant {} so far.'.format(
                    counter, counter_tot, mode, counter_rel, mode))
    return relevant_parts
# Zero-shot classification models offered in the UI: Hugging Face model id ->
# human-readable label (relative speed, language). Insertion order matters:
# the model selectbox is built from ``list(CHOICES.keys())`` with a positional
# default ``index``.
CHOICES = {
    'facebook/bart-large-mnli': 'bart-large-mnli (very slow, english)',
    'valhalla/distilbart-mnli-12-1': 'distilbart-mnli-12-1 (slow, english)',
    # CamemBERT is a French model; this checkpoint is fine-tuned on the French
    # part of XNLI.
    'BaptisteDoyen/camembert-base-xnli': 'camembert-base-xnli (fast, french)',
    'typeform/mobilebert-uncased-mnli': 'mobilebert-uncased-mnli (very fast, english)',
    'Sahajtomar/German_Zeroshot': 'German_Zeroshot (slow, german)',
    'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7': 'mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (fast, multilingual)',
}
def format_func(option):
    """Map a model id from CHOICES to its display label (raises KeyError if unknown)."""
    label = CHOICES[option]
    return label
# --- Input: PDF file and comma-separated topics ---------------------------
st.header('File and topics')
uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
topics = st.text_input(label='Enter comma separated sustainability topics of interest.', value='human rights, sustainability')

# --- Settings ---------------------------------------------------------------
st.header('Settings')
col1, col2 = st.columns(2)
with col1:
    # index=3 preselects the fourth entry of CHOICES (typeform/mobilebert-uncased-mnli).
    model = st.selectbox("Select model used to analyze pdf.", options=list(CHOICES.keys()), format_func=format_func, index=3)
    mode = st.selectbox(label='Choose if you want to detect relevant paragraphs or sentences.', options=['paragraphs', 'sentences'])
with col2:
    min_chars = st.number_input(label='Minimum number of characters to analyze in a text', min_value=0, max_value=500, value=20)
    # The widget works in percent; convert to a fraction for comparison with
    # the classifier scores.
    probability = st.number_input(label='Minimum probability of being relevant to accept (in percent)', min_value=0, max_value=100, value=90) / 100

# Split the free-text topics on commas, trim them, and drop empty entries
# (e.g. from a trailing comma or ",,") that would otherwise become blank
# candidate labels.
topics = [topic.strip() for topic in topics.split(',') if topic.strip()]
# --- Analysis -------------------------------------------------------------
st.header('Analyze PDF')
if st.button('Analyze PDF'):
    if uploaded_file is None:
        # Nothing was uploaded; don't hand None to the PDF reader.
        st.warning('Please upload a .pdf file first.')
    elif not any(topics):
        # No non-empty topic to classify against.
        st.warning('Please enter at least one topic.')
    else:
        with st.spinner('Reading PDF...'):
            text = read_pdf(uploaded_file)
            page_count = ttt.count_pages(uploaded_file)
            # Presumably detect_language returns a sequence whose first item is
            # the detected language; confirm against text_transformation_tools.
            language = ttt.detect_language(' '.join(text))[0]
        st.subheader('Overview')
        st.write('Our pdf reader detected {} pages and {} paragraphs. We assume that the language of this text is "{}".'.format(page_count, len(text), language))
        st.subheader('Analysis')
        relevant_parts = analyze_text(text, topics, model, mode, min_chars, probability)
        counts = [len(relevant_parts[topic]) for topic in topics]
        # mode is already plural ('paragraphs' / 'sentences'), so no 's' suffix.
        fig = px.bar(x=topics, y=counts, title='Found {} of Relevance'.format(mode))
        st.plotly_chart(fig)
        st.subheader('Relevant Passages')
        st.write(relevant_parts)