# NOTE: the following lines look like page-scrape residue, not program text:
# Spaces:
# Sleeping
# Sleeping
import streamlit as st | |
import os | |
import tempfile | |
from typing import List | |
from unified_document_processor import UnifiedDocumentProcessor, CustomEmbeddingFunction | |
import chromadb | |
from chromadb.config import Settings | |
from groq import Groq | |
def initialize_session_state():
    """Initialize all session state variables.

    Seeds three keys in ``st.session_state`` if they are not already present:

    * ``CHROMADB_DIR`` -- ``<current working directory>/chromadb_data``; the
      directory is created on disk when the key is first set.
    * ``processed_files`` -- an empty ``set``.
    * ``processor`` -- ``None``; the real processor object is created later
      by ``StreamlitDocProcessor.__init__``.
    """
    if 'CHROMADB_DIR' not in st.session_state:
        st.session_state.CHROMADB_DIR = os.path.join(os.getcwd(), 'chromadb_data')
        os.makedirs(st.session_state.CHROMADB_DIR, exist_ok=True)

    if 'processed_files' not in st.session_state:
        st.session_state.processed_files = set()

    if 'processor' not in st.session_state:
        # Placeholder only. No try/except here: nothing is being initialized
        # yet, so an "Error initializing processor" handler would be misleading.
        st.session_state.processor = None  # Will be initialized in StreamlitDocProcessor
class StreamlitDocProcessor:
    """Streamlit front end for uploading documents and querying them.

    The document processor lives in ``st.session_state.processor`` and the
    names of already-handled files in ``st.session_state.processed_files``;
    both keys are expected to exist before this class is instantiated.
    """

    def __init__(self):
        """Create the shared processor if the session does not have one yet.

        Reads ``GROQ_API_KEY`` from ``st.secrets``. Any failure is reported
        through ``st.error`` and leaves ``st.session_state.processor`` unset
        (``None``) instead of raising.
        """
        if st.session_state.processor is not None:
            return
        try:
            groq_api_key = st.secrets["GROQ_API_KEY"]
            # Initialize processor with persistent ChromaDB
            st.session_state.processor = self.initialize_processor(groq_api_key)
            # Update processed files after initializing processor
            st.session_state.processed_files = self.get_processed_files()
        except Exception as e:
            # UI boundary: surface the problem to the user rather than crash.
            st.error(f"Error initializing processor: {str(e)}")

    def initialize_processor(self, groq_api_key):
        """Initialize the processor with persistent ChromaDB.

        Args:
            groq_api_key: API key passed to the ``Groq`` client.

        Returns:
            A ``UnifiedDocumentProcessor`` subclass instance whose ChromaDB
            client persists to ``st.session_state.CHROMADB_DIR``.
        """
        class PersistentUnifiedDocumentProcessor(UnifiedDocumentProcessor):
            """``UnifiedDocumentProcessor`` wired to an on-disk ChromaDB client."""

            def __init__(self, api_key, collection_name="unified_content", persist_dir=None):
                # NOTE(review): super().__init__() is not called; the
                # attributes are assigned by hand instead. Presumably this
                # avoids the base class building its own (non-persistent)
                # client -- confirm against UnifiedDocumentProcessor.
                self.groq_client = Groq(api_key=api_key)
                self.max_elements_per_chunk = 50
                self.pdf_chunk_size = 500
                self.pdf_overlap = 50
                self._initialize_nltk()

                # Initialize persistent ChromaDB
                self.chroma_client = chromadb.PersistentClient(
                    path=persist_dir,
                    settings=Settings(
                        allow_reset=True,
                        is_persistent=True,
                    ),
                )

                # Get or create collection in one call, so that unrelated
                # failures are not swallowed by a bare ``except``.
                self.collection = self.chroma_client.get_or_create_collection(
                    name=collection_name,
                    embedding_function=CustomEmbeddingFunction(),
                )

        return PersistentUnifiedDocumentProcessor(
            groq_api_key,
            persist_dir=st.session_state.CHROMADB_DIR,
        )

    def get_processed_files(self) -> set:
        """Get the set of processed file names reported by the processor.

        Returns:
            The union of the ``'pdf'`` and ``'xml'`` lists returned by the
            processor's ``get_available_files()``; an empty set when there is
            no processor or when the lookup fails (the failure is shown with
            ``st.error``).
        """
        try:
            if st.session_state.processor:
                available_files = st.session_state.processor.get_available_files()
                return set(available_files['pdf'] + available_files['xml'])
            return set()
        except Exception as e:
            st.error(f"Error getting processed files: {str(e)}")
            return set()

    def run(self):
        """Render the title and sidebar, then dispatch to the selected page."""
        st.title("AAS Assistant")

        # Create sidebar for navigation
        page = st.sidebar.selectbox(
            "Choose a page",
            ["Upload & Process", "Query"],
        )

        if page == "Upload & Process":
            self.upload_and_process_page()
        else:
            self.qa_page()

    @staticmethod
    def _save_upload_to_temp(uploaded_file) -> str:
        """Copy an uploaded file to a closed temporary file and return its path.

        The temporary file keeps the extension of ``uploaded_file.name`` and
        is NOT deleted automatically; the caller must remove it.
        """
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(uploaded_file.getbuffer())
            # Leaving the ``with`` closes (and flushes) the file before the
            # path is handed to whoever reads it next.
            return tmp_file.name

    def _process_upload(self, uploaded_file, progress_bar, status_text):
        """Run one uploaded file through the processor and report progress."""
        # Sentinel so the cleanup below never touches an unbound name or a
        # path left over from a previous file.
        temp_path = None
        try:
            # NOTE(review): the processor only ever sees this temporary path,
            # not the uploaded name. If process_file() records names derived
            # from the path, they may not match ``uploaded_file.name`` --
            # confirm against UnifiedDocumentProcessor.process_file.
            temp_path = self._save_upload_to_temp(uploaded_file)

            # Process the file
            status_text.text(f'Processing {uploaded_file.name}...')
            progress_bar.progress(25)
            result = st.session_state.processor.process_file(temp_path)
            progress_bar.progress(75)

            if result['success']:
                st.session_state.processed_files.add(uploaded_file.name)
                progress_bar.progress(100)
                status_text.success(f"Successfully processed {uploaded_file.name}")
            else:
                progress_bar.progress(100)
                status_text.error(f"Failed to process {uploaded_file.name}: {result['error']}")
        except Exception as e:
            status_text.error(f"Error processing {uploaded_file.name}: {str(e)}")
        finally:
            # Clean up temporary file (best effort).
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    # The file is already gone or cannot be removed; nothing
                    # useful can be done about it from the UI.
                    pass

    def upload_and_process_page(self):
        """Render the upload page and process any newly uploaded files.

        Files whose name is already in ``st.session_state.processed_files``
        are skipped with an informational message.
        """
        st.header("Upload and Process Documents")

        # File uploader
        uploaded_files = st.file_uploader(
            "Upload PDF or XML files",
            type=['pdf', 'xml'],
            accept_multiple_files=True,
        )

        if uploaded_files:
            for uploaded_file in uploaded_files:
                # Create progress bar
                progress_bar = st.progress(0)
                status_text = st.empty()

                if uploaded_file.name in st.session_state.processed_files:
                    status_text.info(f"{uploaded_file.name} has already been processed")
                    progress_bar.progress(100)
                    continue

                self._process_upload(uploaded_file, progress_bar, status_text)

        # Display processed files
        if st.session_state.processed_files:
            st.subheader("Processed Files")
            for name in sorted(st.session_state.processed_files):
                st.text(f"✓ {name}")

    def qa_page(self):
        """Render the query page: pick files, ask a question, show the answer.

        The list of available files is refreshed from the processor on every
        render. Errors are shown with ``st.error`` rather than raised.
        """
        st.header("Query our database")

        try:
            # Refresh available files
            st.session_state.processed_files = self.get_processed_files()

            if not st.session_state.processed_files:
                st.warning("No processed files available. Please upload and process some files first.")
                return

            # File selection
            selected_files = st.multiselect(
                "Select files to search through",
                sorted(st.session_state.processed_files),
                default=list(st.session_state.processed_files),
            )

            if not selected_files:
                st.warning("Please select at least one file to search through.")
                return

            # Question input
            question = st.text_input("Enter your question:")

            if st.button("Ask Question") and question:
                try:
                    with st.spinner("Searching for answer..."):
                        answer = st.session_state.processor.ask_question_selective(
                            question,
                            selected_files,
                        )
                    st.write("Answer:", answer)
                except Exception as e:
                    st.error(f"Error getting answer: {str(e)}")
        except Exception as e:
            st.error(f"Error in Q&A interface: {str(e)}")
def main():
    """Entry point: seed the session state, then build and run the app."""
    initialize_session_state()
    StreamlitDocProcessor().run()


if __name__ == "__main__":
    main()