|
import streamlit as st |
|
import text_transformation_tools as ttt |
|
from transformers import pipeline |
|
import plotly.express as px |
|
|
|
|
|
def read_pdf(file):
    """Extract text from an uploaded PDF.

    Args:
        file: file-like PDF object (e.g. the value returned by
            ``st.file_uploader``).

    Returns:
        The text extracted by ``ttt.pdf_to_text`` — presumably a list of
        paragraph strings, based on how callers iterate/len() it.
    """
    # Bug fix: the original read the module-level `uploaded_file` global
    # instead of the `file` argument, silently ignoring the parameter.
    return ttt.pdf_to_text(file)
|
|
|
def analyze_text(paragraphs, topics, model, mode, min_chars, prob):
    """Zero-shot-classify text fragments against the given topics.

    Args:
        paragraphs: list of paragraph strings extracted from the PDF.
        topics: list of candidate topic labels.
        model: Hugging Face model id for the zero-shot pipeline.
        mode: 'paragraphs' to classify whole paragraphs, 'sentences' to
            split each paragraph on '.' and classify sentences.
        min_chars: fragments shorter than this are skipped (not classified).
        prob: minimum score (0-1) for a fragment to count as relevant.

    Returns:
        dict mapping each topic to the list of fragments whose score for
        that topic was >= ``prob``.
    """
    with st.spinner('Loading model'):
        classifier = pipeline('zero-shot-classification', model=model)

    # One bucket per topic for the relevant fragments.
    relevant_parts = {topic: [] for topic in topics}

    if mode == 'sentences':
        # Naive sentence split on '.'; empty/short fragments are filtered
        # out below by the min_chars threshold.
        text = [sentence for paragraph in paragraphs
                for sentence in paragraph.split('.')]
    else:
        # Default to paragraph mode for any other value instead of leaving
        # `text` undefined (the original raised NameError on unexpected mode).
        text = paragraphs

    with st.spinner('Analyzing text...'):
        counter = 0
        counter_rel = 0
        counter_tot = len(text)

        # st.empty() gives a single placeholder so the progress line below
        # overwrites itself instead of appending a new line per fragment.
        with st.empty():
            for sequence_to_classify in text:
                # Strip newlines and collapse double spaces before measuring
                # length / classifying. (The original's replace(' ', ' ')
                # was a no-op; double-space collapsing was the clear intent.)
                cleansed_sequence = sequence_to_classify.replace('\n', '').replace('  ', ' ')

                if len(cleansed_sequence) >= min_chars:
                    classified = classifier(cleansed_sequence, topics, multi_label=True)

                    # Labels and scores are parallel lists sorted by score.
                    for label, score in zip(classified['labels'], classified['scores']):
                        if score >= prob:
                            # Keep the original (uncleansed) fragment so the
                            # displayed passages match the source text.
                            relevant_parts[label].append(sequence_to_classify)
                            counter_rel += 1

                counter += 1
                st.write('Analyzed {} of {} {}. Found {} relevant {} so far.'.format(counter, counter_tot, mode, counter_rel, mode))

    return relevant_parts
|
|
|
|
|
# Model id -> human-readable label (approximate speed, supported language).
CHOICES = {
    'facebook/bart-large-mnli': 'bart-large-mnli (very slow, english)',
    'valhalla/distilbart-mnli-12-1': 'distilbart-mnli-12-1 (slow, english)',
    # CamemBERT is a French-language model; the label previously (and
    # incorrectly) said "english".
    'BaptisteDoyen/camembert-base-xnli': 'camembert-base-xnli (fast, french)',
    'typeform/mobilebert-uncased-mnli': 'mobilebert-uncased-mnli (very fast, english)',
    'Sahajtomar/German_Zeroshot': 'German_Zeroshot (slow, german)',
    'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7': 'mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (fast, multilingual)',
}


def format_func(option):
    """Return the human-readable label for a model id (for st.selectbox)."""
    return CHOICES[option]
|
|
|
# --- File upload and topic input -------------------------------------------
st.header('File and topics')

uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")

# Typo fix in user-facing label: "coma separated" -> "comma separated".
topics = st.text_input(label='Enter comma separated sustainability topics of interest.', value='human rights, sustainability')


# --- Analysis settings ------------------------------------------------------
st.header('Settings')

col1, col2 = st.columns(2)

with col1:
    # Default index=3 selects the "very fast" mobilebert model.
    model = st.selectbox("Select model used to analyze pdf.", options=list(CHOICES.keys()), format_func=format_func, index=3)

    # Typo fix in user-facing label: "Chose" -> "Choose".
    mode = st.selectbox(label='Choose if you want to detect relevant paragraphs or sentences.', options=['paragraphs', 'sentences'])

with col2:
    min_chars = st.number_input(label='Minimum number of characters to analyze in a text', min_value=0, max_value=500, value=20)

    # UI takes a percentage; convert to a 0-1 probability for the classifier.
    probability = st.number_input(label='Minimum probability of being relevant to accept (in percent)', min_value=0, max_value=100, value=90)/100


# Normalize the comma-separated topic string into a clean list of labels.
topics = [topic.strip() for topic in topics.split(',')]
|
|
|
# --- Run the analysis -------------------------------------------------------
st.header('Analyze PDF')

if st.button('Analyze PDF'):
    # Guard: the original crashed inside read_pdf when no file was uploaded.
    if uploaded_file is None:
        st.error('Please upload a PDF file before analyzing.')
        st.stop()

    with st.spinner('Reading PDF...'):
        text = read_pdf(uploaded_file)
        page_count = ttt.count_pages(uploaded_file)
        # detect_language presumably returns (language, confidence) — we
        # only display the language code/name.
        language = ttt.detect_language(' '.join(text))[0]

    st.subheader('Overview')
    st.write('Our pdf reader detected {} pages and {} paragraphs. We assume that the language of this text is "{}".'.format(page_count, len(text), language))

    st.subheader('Analysis')
    relevant_parts = analyze_text(text, topics, model, mode, min_chars, probability)

    # Number of relevant fragments found per topic, in topic order.
    counts = [len(relevant_parts[topic]) for topic in topics]

    # Bug fix: `mode` is already plural ('paragraphs'/'sentences'); the
    # original '{}s' produced titles like "Found paragraphss of Relevance".
    fig = px.bar(x=topics, y=counts, title='Found {} of Relevance'.format(mode))
    st.plotly_chart(fig)

    st.subheader('Relevant Passages')
    st.write(relevant_parts)
|
|
|
|