Spaces:
Sleeping
Sleeping
File size: 7,972 Bytes
35ecede a79ad25 35ecede a79ad25 f294a04 35ecede f294a04 a79ad25 f294a04 a79ad25 f294a04 a79ad25 35ecede 74ff41b 35ecede 74ff41b 35ecede a79ad25 35ecede a79ad25 35ecede a79ad25 35ecede a79ad25 35ecede a79ad25 35ecede a79ad25 35ecede a79ad25 35ecede a79ad25 35ecede 74ff41b 35ecede a79ad25 35ecede a79ad25 35ecede a79ad25 35ecede a79ad25 35ecede f294a04 35ecede |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import os
import shutil
import tempfile
from typing import List

import chromadb
import streamlit as st
from chromadb.config import Settings
from groq import Groq

from unified_document_processor import UnifiedDocumentProcessor, CustomEmbeddingFunction
def initialize_session_state():
    """Ensure the session-state keys the app relies on exist.

    Each key is set only when missing, so values survive Streamlit reruns:
      * ``CHROMADB_DIR`` - ``<cwd>/chromadb_data``, the ChromaDB storage dir;
      * ``processed_files`` - an empty set of processed file names;
      * ``processor`` - ``None``; the real processor is built lazily by
        ``StreamlitDocProcessor``.
    """
    if 'CHROMADB_DIR' not in st.session_state:
        st.session_state.CHROMADB_DIR = os.path.join(os.getcwd(), 'chromadb_data')
    # Idempotent, so run it every time: recreates the directory if it was
    # removed while the session was still alive.
    os.makedirs(st.session_state.CHROMADB_DIR, exist_ok=True)
    if 'processed_files' not in st.session_state:
        st.session_state.processed_files = set()
    if 'processor' not in st.session_state:
        # Placeholder only; StreamlitDocProcessor replaces it on first run.
        st.session_state.processor = None
class StreamlitDocProcessor:
    """Streamlit front end for uploading documents and querying them.

    The document processor is cached in ``st.session_state`` so it survives
    Streamlit's script reruns. ``initialize_session_state()`` must run first
    so that ``processor``, ``processed_files`` and ``CHROMADB_DIR`` exist.
    """

    def __init__(self):
        """Build the persistent processor once per session.

        On failure the error is shown in the UI and ``processor`` stays
        ``None``, so construction is attempted again on the next rerun.
        """
        if st.session_state.processor is not None:
            return
        try:
            groq_api_key = st.secrets["GROQ_API_KEY"]
            # Initialize processor with persistent ChromaDB
            st.session_state.processor = self.initialize_processor(groq_api_key)
            # Sync the processed-file list with what is already stored
            st.session_state.processed_files = self.get_processed_files()
        except Exception as e:
            st.error(f"Error initializing processor: {str(e)}")

    def initialize_processor(self, groq_api_key):
        """Create a UnifiedDocumentProcessor backed by a PersistentClient.

        Args:
            groq_api_key: API key handed to the Groq client.

        Returns:
            A processor whose Chroma collection is stored under
            ``st.session_state.CHROMADB_DIR``.
        """

        class PersistentUnifiedDocumentProcessor(UnifiedDocumentProcessor):
            # NOTE(review): the parent __init__ is deliberately not called;
            # the attributes below presumably mirror what it sets, with a
            # persistent Chroma client swapped in. Verify against
            # unified_document_processor whenever the parent changes.
            def __init__(self, api_key, collection_name="unified_content", persist_dir=None):
                self.groq_client = Groq(api_key=api_key)
                self.max_elements_per_chunk = 50
                self.pdf_chunk_size = 500
                self.pdf_overlap = 50
                self._initialize_nltk()
                # Initialize persistent ChromaDB
                self.chroma_client = chromadb.PersistentClient(
                    path=persist_dir,
                    settings=Settings(
                        allow_reset=True,
                        is_persistent=True
                    )
                )
                # Reuse the stored collection or create it. A get/create
                # fallback behind a bare ``except:`` would also hide
                # unrelated failures (and KeyboardInterrupt/SystemExit).
                self.collection = self.chroma_client.get_or_create_collection(
                    name=collection_name,
                    embedding_function=CustomEmbeddingFunction()
                )

        return PersistentUnifiedDocumentProcessor(
            groq_api_key,
            persist_dir=st.session_state.CHROMADB_DIR
        )

    def get_processed_files(self) -> set:
        """Return the names of files the processor reports as stored.

        Combines the ``'pdf'`` and ``'xml'`` entries of
        ``processor.get_available_files()``. Returns an empty set when there
        is no processor or the lookup fails (the error is shown in the UI).
        """
        try:
            processor = st.session_state.processor
            if not processor:
                return set()
            available_files = processor.get_available_files()
            # Presumably a dict of name lists keyed by file type; a missing
            # key is treated as "no files of that type".
            return set(available_files.get('pdf', [])) | set(available_files.get('xml', []))
        except Exception as e:
            st.error(f"Error getting processed files: {str(e)}")
            return set()

    def run(self):
        """Render the title and dispatch to the page chosen in the sidebar."""
        st.title("AAS Assistant")
        # Create sidebar for navigation
        page = st.sidebar.selectbox(
            "Choose a page",
            ["Upload & Process", "Query"]
        )
        if page == "Upload & Process":
            self.upload_and_process_page()
        else:
            self.qa_page()

    def upload_and_process_page(self):
        """Render the upload page, index new uploads, list processed files."""
        st.header("Upload and Process Documents")
        # File uploader
        uploaded_files = st.file_uploader(
            "Upload PDF or XML files",
            type=['pdf', 'xml'],
            accept_multiple_files=True
        )
        if uploaded_files:
            if st.session_state.processor is None:
                # __init__ failed (its error is already on the page); without
                # this guard every file would fail with an AttributeError.
                st.error("Document processor is not initialized; cannot process files.")
            else:
                for uploaded_file in uploaded_files:
                    self._process_uploaded_file(uploaded_file)
        # Display processed files
        if st.session_state.processed_files:
            st.subheader("Processed Files")
            for file in sorted(st.session_state.processed_files):
                st.text(f"✓ {file}")

    def _process_uploaded_file(self, uploaded_file):
        """Index a single upload, reporting progress; skip known file names."""
        progress_bar = st.progress(0)
        status_text = st.empty()
        if uploaded_file.name in st.session_state.processed_files:
            status_text.info(f"{uploaded_file.name} has already been processed")
            progress_bar.progress(100)
            return
        # Save the upload under its own (base)name in a private temp dir so the
        # processor sees the user's filename instead of a random tmpXXXX name
        # (presumably the name it records - verify against process_file).
        # basename() drops directory parts of the client-supplied name.
        tmp_dir = tempfile.mkdtemp()
        try:
            temp_path = os.path.join(tmp_dir, os.path.basename(uploaded_file.name))
            with open(temp_path, 'wb') as tmp_file:
                tmp_file.write(uploaded_file.getbuffer())
            # Process the file
            status_text.text(f'Processing {uploaded_file.name}...')
            progress_bar.progress(25)
            result = st.session_state.processor.process_file(temp_path)
            progress_bar.progress(75)
            if result['success']:
                st.session_state.processed_files.add(uploaded_file.name)
                progress_bar.progress(100)
                status_text.success(f"Successfully processed {uploaded_file.name}")
            else:
                progress_bar.progress(100)
                status_text.error(
                    f"Failed to process {uploaded_file.name}: {result.get('error', 'unknown error')}"
                )
        except Exception as e:
            status_text.error(f"Error processing {uploaded_file.name}: {str(e)}")
        finally:
            # Best-effort cleanup; a leftover temp dir must not mask the result.
            shutil.rmtree(tmp_dir, ignore_errors=True)

    def qa_page(self):
        """Render the query page: choose indexed files and ask a question."""
        st.header("Query our database")
        try:
            # Refresh available files from the processor's store
            st.session_state.processed_files = self.get_processed_files()
            if not st.session_state.processed_files:
                st.warning("No processed files available. Please upload and process some files first.")
                return
            # One sorted list for both options and default keeps the widget's
            # initial selection in a deterministic order.
            available_files = sorted(st.session_state.processed_files)
            selected_files = st.multiselect(
                "Select files to search through",
                available_files,
                default=available_files
            )
            if not selected_files:
                st.warning("Please select at least one file to search through.")
                return
            # Question input
            question = st.text_input("Enter your question:")
            if st.button("Ask Question") and question:
                try:
                    with st.spinner("Searching for answer..."):
                        answer = st.session_state.processor.ask_question_selective(
                            question,
                            selected_files
                        )
                        st.write("Answer:", answer)
                except Exception as e:
                    st.error(f"Error getting answer: {str(e)}")
        except Exception as e:
            st.error(f"Error in Q&A interface: {str(e)}")
def main():
    """App entry point: prepare session state, then build and render the UI."""
    initialize_session_state()
    StreamlitDocProcessor().run()
# Streamlit executes the script as __main__, so this also runs under
# `streamlit run`; importing the module does not start the app.
if __name__ == "__main__":
    main()