import streamlit as st
from transformers import pipeline
import re

def custom_sentence_splitter(text):
    # Split on sentence-ending punctuation (., !, ?) followed by one or more spaces
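    # Example: custom_sentence_splitter("First sentence. Second one!") -> ['First sentence.', 'Second one!']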
    return re.split(r'(?<=[.!?]) +', text)

st.title('Hugging Face Text Summarizer')

# List of models
models = ["sshleifer/distilbart-cnn-12-6", "facebook/bart-large-cnn", "t5-base", "t5-large", "google/pegasus-newsroom"]

# Dropdown model selector
model = st.sidebar.selectbox("Choose a model", models)

uploaded_file = st.file_uploader("Choose a .txt file", type="txt")

# Add text input for keywords
keywords = st.text_input("Enter keywords (comma-separated)")

# Slider for the target summary length, as a percentage of each chunk (a ±10% band is applied below)
scale_percentage = st.sidebar.slider('Scale %', min_value=1, max_value=100, value=50)

# Add slider for the chunk size
chunk_size = st.sidebar.slider('Chunk size (words)', min_value=100, max_value=1000, value=500)

if uploaded_file is not None and keywords:
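    # Assumes the uploaded .txt file is UTF-8 encoded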
    user_input = uploaded_file.read().decode('utf-8')
    keywords = [keyword.strip() for keyword in keywords.split(",")]

    # Split text into sentences using the custom function
    sentences = custom_sentence_splitter(user_input)
    
    # Filter sentences based on keywords
    filtered_sentences = [sentence for sentence in sentences if any(keyword.lower() in sentence.lower() for keyword in keywords)]
    filtered_text = ' '.join(filtered_sentences)

    if not filtered_sentences:
        st.warning('No sentences matched the given keywords.')
    elif st.button('Summarize'):
        summarizer = pipeline('summarization', model=model)
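        # Note: the model is loaded again on every click. In recent Streamlit versions
        # (which provide st.cache_resource) it could be cached across reruns, e.g.
        # (a minimal sketch, not part of the original app):
        #
        #     @st.cache_resource
        #     def load_summarizer(model_name):
        #         return pipeline('summarization', model=model_name)
        #
        #     summarizer = load_summarizer(model)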
        summarized_text = ""

        # Split filtered text into chunks by words
        words = filtered_text.split()
        chunks = [' '.join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
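        # e.g. with chunk_size=500, a 1,200-word filtered text yields chunks of 500, 500, and 200 words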
        
        # The ±10% band around the chosen scale does not depend on the chunk, so compute it once
        min_length_percentage = max(scale_percentage - 10, 1)
        max_length_percentage = min(scale_percentage + 10, 100)

        # Summarize each chunk. min_length/max_length are token counts for the generated
        # summary, while chunk_length counts words, so the scale is only approximate.
        for chunk in chunks:
            chunk_length = len(chunk.split())
            min_length = max(int(chunk_length * min_length_percentage / 100), 1)
            # Keep max_length above min_length so very short chunks don't produce an invalid range
            max_length = max(int(chunk_length * max_length_percentage / 100), min_length + 1)
            summarized = summarizer(chunk, max_length=max_length, min_length=min_length, do_sample=False)
            summarized_text += summarized[0]['summary_text'] + " "

        st.text_area('Summarized Text', summarized_text, height=200)