# Smart_AAS / app.py
# TahaRasouli's picture
# Update app.py
# 74ff41b verified
import streamlit as st
import os
import tempfile
from typing import List
from unified_document_processor import UnifiedDocumentProcessor, CustomEmbeddingFunction
import chromadb
from chromadb.config import Settings
from groq import Groq
def initialize_session_state():
    """Seed ``st.session_state`` with the keys the rest of the app reads.

    Each key is only created when it is missing, so existing values are
    left untouched on later calls:

    * ``CHROMADB_DIR`` -- ``chromadb_data`` under the current working
      directory.
    * ``processed_files`` -- an empty ``set``.
    * ``processor`` -- ``None`` placeholder; ``StreamlitDocProcessor``
      replaces it with a real processor instance.
    """
    if 'CHROMADB_DIR' not in st.session_state:
        st.session_state.CHROMADB_DIR = os.path.join(os.getcwd(), 'chromadb_data')
    # Idempotent thanks to exist_ok=True; running it on every call also
    # restores the directory if it was removed while the session was alive.
    os.makedirs(st.session_state.CHROMADB_DIR, exist_ok=True)

    if 'processed_files' not in st.session_state:
        st.session_state.processed_files = set()

    if 'processor' not in st.session_state:
        # A plain assignment needs no error handling; the real (fallible)
        # initialization happens in StreamlitDocProcessor.__init__.
        st.session_state.processor = None
class StreamlitDocProcessor:
    """Streamlit front end for uploading documents and querying them.

    The document processor itself lives in ``st.session_state.processor``;
    this class only wires Streamlit widgets to it.
    """

    def __init__(self):
        """Create the shared processor if the session does not have one yet.

        Reads ``GROQ_API_KEY`` from ``st.secrets``. Any failure is reported
        through ``st.error`` and leaves ``st.session_state.processor`` as
        ``None``; the page methods cope with that state.
        """
        if st.session_state.processor is not None:
            return
        try:
            groq_api_key = st.secrets["GROQ_API_KEY"]
            # Initialize processor with persistent ChromaDB
            st.session_state.processor = self.initialize_processor(groq_api_key)
            # Update processed files after initializing processor
            st.session_state.processed_files = self.get_processed_files()
        except Exception as e:
            st.error(f"Error initializing processor: {str(e)}")

    def initialize_processor(self, groq_api_key):
        """Build a processor backed by a persistent ChromaDB collection.

        Args:
            groq_api_key: API key handed to the ``Groq`` client.

        Returns:
            A ``UnifiedDocumentProcessor`` subclass instance whose ChromaDB
            client persists to ``st.session_state.CHROMADB_DIR``.
        """
        class PersistentUnifiedDocumentProcessor(UnifiedDocumentProcessor):
            def __init__(self, api_key, collection_name="unified_content", persist_dir=None):
                # NOTE(review): the parent __init__ is intentionally not
                # called here -- presumably because it sets up its own
                # (non-persistent) client. Confirm against
                # UnifiedDocumentProcessor before adding super().__init__().
                self.groq_client = Groq(api_key=api_key)
                # Chunking parameters read by the parent class (not shown here).
                self.max_elements_per_chunk = 50
                self.pdf_chunk_size = 500
                self.pdf_overlap = 50
                self._initialize_nltk()

                # Initialize persistent ChromaDB
                self.chroma_client = chromadb.PersistentClient(
                    path=persist_dir,
                    settings=Settings(
                        allow_reset=True,
                        is_persistent=True,
                    ),
                )

                # One call instead of get/except/create: a bare ``except``
                # around get_collection would also hide unrelated failures.
                self.collection = self.chroma_client.get_or_create_collection(
                    name=collection_name,
                    embedding_function=CustomEmbeddingFunction(),
                )

        return PersistentUnifiedDocumentProcessor(
            groq_api_key,
            persist_dir=st.session_state.CHROMADB_DIR,
        )

    def get_processed_files(self) -> set:
        """Return the file names the processor reports as available.

        Combines the ``'pdf'`` and ``'xml'`` lists returned by the
        processor's ``get_available_files()``. Returns an empty set when no
        processor exists or when the lookup fails (the failure is shown via
        ``st.error``).
        """
        try:
            if st.session_state.processor:
                available_files = st.session_state.processor.get_available_files()
                return set(available_files['pdf'] + available_files['xml'])
            return set()
        except Exception as e:
            st.error(f"Error getting processed files: {str(e)}")
            return set()

    def run(self):
        """Render the title and sidebar, then dispatch to the chosen page."""
        st.title("AAS Assistant")

        # Create sidebar for navigation
        page = st.sidebar.selectbox(
            "Choose a page",
            ["Upload & Process", "Query"]
        )

        if page == "Upload & Process":
            self.upload_and_process_page()
        else:
            self.qa_page()

    @staticmethod
    def _write_temp_copy(uploaded_file):
        """Copy an uploaded file to a temp file and return the temp path.

        ``uploaded_file`` must expose ``name`` and ``getbuffer()``. The temp
        file keeps the original extension. It is created with
        ``delete=False``, so the caller owns its removal; if the write
        itself fails the temp file is removed here before re-raising.
        """
        suffix = os.path.splitext(uploaded_file.name)[1]
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            with tmp_file:
                tmp_file.write(uploaded_file.getbuffer())
        except Exception:
            # Do not leak the half-written file; removal is best-effort.
            try:
                os.unlink(tmp_file.name)
            except OSError:
                pass
            raise
        return tmp_file.name

    def _process_upload(self, uploaded_file):
        """Process a single upload, reporting progress and outcome in the UI."""
        # Create progress bar
        progress_bar = st.progress(0)
        status_text = st.empty()

        if uploaded_file.name in st.session_state.processed_files:
            status_text.info(f"{uploaded_file.name} has already been processed")
            progress_bar.progress(100)
            return

        # Sentinel so ``finally`` never touches an unbound (or stale) path
        # when creating the temp file is what failed.
        temp_path = None
        try:
            temp_path = self._write_temp_copy(uploaded_file)

            # Process the file
            status_text.text(f'Processing {uploaded_file.name}...')
            progress_bar.progress(25)

            # NOTE(review): the processor only ever sees the temp path, while
            # this page tracks the original upload name. If
            # get_available_files() reports names derived from the processed
            # path, the two sets will disagree -- verify against the
            # processor implementation.
            result = st.session_state.processor.process_file(temp_path)
            progress_bar.progress(75)

            if result['success']:
                st.session_state.processed_files.add(uploaded_file.name)
                progress_bar.progress(100)
                status_text.success(f"Successfully processed {uploaded_file.name}")
            else:
                progress_bar.progress(100)
                status_text.error(f"Failed to process {uploaded_file.name}: {result['error']}")
        except Exception as e:
            status_text.error(f"Error processing {uploaded_file.name}: {str(e)}")
        finally:
            # Clean up temporary file (best-effort: a failed delete must not
            # mask the processing result shown above).
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

    def upload_and_process_page(self):
        """Render the upload page and process each newly uploaded file."""
        st.header("Upload and Process Documents")

        # File uploader
        uploaded_files = st.file_uploader(
            "Upload PDF or XML files",
            type=['pdf', 'xml'],
            accept_multiple_files=True
        )

        if uploaded_files:
            if st.session_state.processor is None:
                # Initialization failed earlier; say so once instead of
                # emitting a NoneType attribute error for every file.
                st.error("Document processor is not available; uploaded files cannot be processed.")
            else:
                for uploaded_file in uploaded_files:
                    self._process_upload(uploaded_file)

        # Display processed files
        if st.session_state.processed_files:
            st.subheader("Processed Files")
            for name in sorted(st.session_state.processed_files):
                st.text(f"✓ {name}")

    def qa_page(self):
        """Render the query page: pick files, ask a question, show the answer."""
        st.header("Query our database")

        try:
            # Refresh available files
            st.session_state.processed_files = self.get_processed_files()

            if not st.session_state.processed_files:
                st.warning("No processed files available. Please upload and process some files first.")
                return

            # File selection; the same sorted list serves as options and as
            # the default so the pre-selected order is deterministic.
            available = sorted(st.session_state.processed_files)
            selected_files = st.multiselect(
                "Select files to search through",
                available,
                default=available
            )

            if not selected_files:
                st.warning("Please select at least one file to search through.")
                return

            # Question input
            question = st.text_input("Enter your question:")

            if st.button("Ask Question") and question:
                try:
                    with st.spinner("Searching for answer..."):
                        answer = st.session_state.processor.ask_question_selective(
                            question,
                            selected_files
                        )
                    st.write("Answer:", answer)
                except Exception as e:
                    st.error(f"Error getting answer: {str(e)}")
        except Exception as e:
            st.error(f"Error in Q&A interface: {str(e)}")
def main():
    """Entry point: prepare the session state, then build and run the app."""
    # Session keys must exist before the app object reads them.
    initialize_session_state()
    StreamlitDocProcessor().run()


if __name__ == "__main__":
    main()