import streamlit as st |
import text_transformation_tools as ttt |
from transformers import pipeline |
import plotly.express as px |
def read_pdf(file):
    """Extract the text of a PDF via ``ttt.pdf_to_text``.

    Args:
        file: The PDF to read, passed straight through to
            ``ttt.pdf_to_text`` (the script passes the object returned by
            ``st.file_uploader``).

    Returns:
        Whatever ``ttt.pdf_to_text`` returns. The calling script treats it
        as a sequence of paragraph strings -- presumably a list of str;
        confirm against ``text_transformation_tools``.
    """
    # Use the argument instead of the module-level ``uploaded_file`` so the
    # function works for any file object, not only the global one.
    return ttt.pdf_to_text(file)
def analyze_text(paragraphs, topics, model, mode, min_chars, prob):
    """Find the passages of a text that a zero-shot model assigns to topics.

    Args:
        paragraphs: Iterable of paragraph strings.
        topics: List of candidate topic labels.
        model: Model identifier handed to the ``transformers`` pipeline.
        mode: ``'paragraphs'`` to classify each paragraph as a whole, or
            ``'sentences'`` to classify the pieces obtained by splitting
            every paragraph on ``'.'``.
        min_chars: Passages whose cleansed text is shorter than this are
            not classified (they still count as analyzed in the progress
            message).
        prob: Minimum score a topic must reach for a passage to be
            recorded under it.

    Returns:
        Dict mapping every topic to the list of original (uncleansed)
        passages that scored at least ``prob`` for it. A passage may appear
        under several topics because classification is multi-label.

    Raises:
        ValueError: If ``mode`` is neither ``'paragraphs'`` nor
            ``'sentences'``.
    """
    # Validate the mode before the expensive model load; previously an
    # unknown mode only failed later with an unbound local variable.
    if mode == 'paragraphs':
        passages = list(paragraphs)
    elif mode == 'sentences':
        passages = [
            sentence
            for paragraph in paragraphs
            for sentence in paragraph.split('.')
        ]
    else:
        raise ValueError(
            "mode must be 'paragraphs' or 'sentences', got {!r}".format(mode)
        )

    relevant_parts = {topic: [] for topic in topics}
    if not topics:
        # The classifier rejects an empty label list, and there is nothing
        # to find anyway, so skip loading the model.
        return relevant_parts

    with st.spinner('Loading model'):
        classifier = pipeline('zero-shot-classification', model=model)

    with st.spinner('Analyzing text...'):
        total = len(passages)
        found = 0
        # st.empty() provides a single slot, so each progress message
        # replaces the previous one instead of piling up on the page.
        with st.empty():
            for analyzed, passage in enumerate(passages, start=1):
                # Drop line breaks and collapse runs of spaces before
                # measuring and classifying the passage.
                cleansed = passage.replace('\n', '')
                while '  ' in cleansed:
                    cleansed = cleansed.replace('  ', ' ')

                if len(cleansed) >= min_chars:
                    classified = classifier(cleansed, topics, multi_label=True)
                    for label, score in zip(classified['labels'], classified['scores']):
                        if score >= prob:
                            # Store the original passage, not the cleansed one.
                            relevant_parts[label].append(passage)
                            found += 1

                st.write('Analyzed {} of {} {}. Found {} relevant {} so far.'.format(analyzed, total, mode, found, mode))
    return relevant_parts
# Zero-shot models offered in the UI: Hugging Face model id -> label shown in
# the model select box. Insertion order matters because the select box picks
# its default entry by position.
# NOTE(review): CamemBERT is a French-language model, so the "english" tag on
# the camembert-base-xnli label looks wrong -- confirm before changing it.
CHOICES = {
    'facebook/bart-large-mnli': 'bart-large-mnli (very slow, english)',
    'valhalla/distilbart-mnli-12-1': 'distilbart-mnli-12-1 (slow, english)',
    'BaptisteDoyen/camembert-base-xnli': 'camembert-base-xnli (fast, english)',
    'typeform/mobilebert-uncased-mnli': 'mobilebert-uncased-mnli (very fast, english)',
    'Sahajtomar/German_Zeroshot': 'German_Zeroshot (slow, german)',
    'MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7': 'mDeBERTa-v3-base-xnli-multilingual-nli-2mil7 (fast, multilingual)',
}
def format_func(option):
    """Return the display label registered in ``CHOICES`` for a model id.

    Passed to the model select box so users see the descriptive label
    instead of the raw identifier. Raises ``KeyError`` when ``option`` is
    not a key of ``CHOICES``.
    """
    label = CHOICES[option]
    return label
# --- Inputs: the document and the topics to look for ------------------------
st.header('File and topics')
uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
topics = st.text_input(
    label='Enter coma separated sustainability topics of interest.',
    value='human rights, sustainability',
)

# --- Settings: model, granularity and thresholds ----------------------------
st.header('Settings')
col1, col2 = st.columns(2)
with col1:
    # index=3 preselects the fourth entry of CHOICES.
    model = st.selectbox(
        "Select model used to analyze pdf.",
        options=list(CHOICES.keys()),
        format_func=format_func,
        index=3,
    )
    mode = st.selectbox(
        label='Chose if you want to detect relevant paragraphs or sentences.',
        options=['paragraphs', 'sentences'],
    )
with col2:
    min_chars = st.number_input(
        label='Minimum number of characters to analyze in a text',
        min_value=0,
        max_value=500,
        value=20,
    )
    # Entered in percent, used as a fraction in [0, 1].
    probability = st.number_input(
        label='Minimum probability of being relevant to accept (in percent)',
        min_value=0,
        max_value=100,
        value=90,
    ) / 100

# Split the comma-separated input and drop empty entries (stray or trailing
# commas) so that no empty label reaches the classifier.
topics = [topic.strip() for topic in topics.split(',') if topic.strip()]

# --- Run the analysis --------------------------------------------------------
st.header('Analyze PDF')
if st.button('Analyze PDF'):
    if uploaded_file is None:
        # st.file_uploader returns None until a file has been chosen.
        st.warning('Please upload a .pdf file before starting the analysis.')
    elif not topics:
        st.warning('Please enter at least one topic before starting the analysis.')
    else:
        with st.spinner('Reading PDF...'):
            text = read_pdf(uploaded_file)
            # Reading may have advanced the upload's stream position; rewind
            # so the page count starts from the beginning of the file.
            uploaded_file.seek(0)
            page_count = ttt.count_pages(uploaded_file)
            language = ttt.detect_language(' '.join(text))[0]

        st.subheader('Overview')
        st.write('Our pdf reader detected {} pages and {} paragraphs. We assume that the language of this text is "{}".'.format(page_count, len(text), language))

        st.subheader('Analysis')
        relevant_parts = analyze_text(text, topics, model, mode, min_chars, probability)

        # One bar per topic showing how many passages were found for it.
        counts = [len(relevant_parts[topic]) for topic in topics]
        # ``mode`` is already plural ('paragraphs' / 'sentences'); the old
        # title appended another "s" and rendered "paragraphss".
        fig = px.bar(x=topics, y=counts, title='Found {} of Relevance'.format(mode))
        st.plotly_chart(fig)

        st.subheader('Relevant Passages')
        st.write(relevant_parts)