# Smart_AAS / unified_document_processor.py
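"""Unified document processor for Smart_AAS.

Converts XML and PDF files into natural-language summaries via the Groq API,
stores the summaries in a ChromaDB collection using sentence-transformer
embeddings, and answers questions over a user-selected subset of files.
"""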
from typing import List, Dict, Optional, Union
from groq import Groq
import chromadb
import os
import datetime
import json
import xml.etree.ElementTree as ET
import nltk
from nltk.tokenize import sent_tokenize
import PyPDF2
from sentence_transformers import SentenceTransformer


class CustomEmbeddingFunction:
    """Embedding function backed by a local sentence-transformers model."""

    def __init__(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def __call__(self, input: List[str]) -> List[List[float]]:
        embeddings = self.model.encode(input)
        return embeddings.tolist()


class UnifiedDocumentProcessor:
    def __init__(self, groq_api_key, collection_name="unified_content"):
        """Initialize the Groq client, NLTK resources, and the ChromaDB collection."""
        self.groq_client = Groq(api_key=groq_api_key)

        # XML-specific settings
        self.max_elements_per_chunk = 50

        # PDF-specific settings (sizes are in words)
        self.pdf_chunk_size = 500
        self.pdf_overlap = 50

        # Initialize NLTK tokenizer resources
        self._initialize_nltk()

        # Initialize ChromaDB with a single collection for all document types
        self.chroma_client = chromadb.Client()
        existing_collections = self.chroma_client.list_collections()
        collection_exists = any(col.name == collection_name for col in existing_collections)
        if collection_exists:
            print(f"Using existing collection: {collection_name}")
            self.collection = self.chroma_client.get_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )
        else:
            print(f"Creating new collection: {collection_name}")
            self.collection = self.chroma_client.create_collection(
                name=collection_name,
                embedding_function=CustomEmbeddingFunction()
            )

    def _initialize_nltk(self):
        """Ensure both NLTK tokenizer resources (punkt and punkt_tab) are available."""
        try:
            nltk.download('punkt')
            try:
                nltk.data.find('tokenizers/punkt_tab')
            except LookupError:
                nltk.download('punkt_tab')
        except Exception as e:
            print(f"Warning: Error downloading NLTK resources: {str(e)}")
            print("Falling back to basic sentence splitting...")

    def _basic_sentence_split(self, text: str) -> List[str]:
        """Fallback method for sentence tokenization."""
        sentences = []
        current = ""
        for char in text:
            current += char
            if char in ['.', '!', '?'] and len(current.strip()) > 0:
                sentences.append(current.strip())
                current = ""
        if current.strip():
            sentences.append(current.strip())
        return sentences

    def process_file(self, file_path: str) -> Dict:
        """Process any supported file type."""
        try:
            file_extension = os.path.splitext(file_path)[1].lower()
            if file_extension == '.xml':
                return self.process_xml_file(file_path)
            elif file_extension == '.pdf':
                return self.process_pdf_file(file_path)
            else:
                return {
                    'success': False,
                    'error': f'Unsupported file type: {file_extension}'
                }
        except Exception as e:
            return {
                'success': False,
                'error': f'Error processing file: {str(e)}'
            }

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from a PDF file."""
        try:
            text = ""
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() may return None for pages without a text layer
                    text += (page.extract_text() or "") + " "
            return text.strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}")

    def chunk_text(self, text: str) -> List[str]:
        """Split text into chunks while preserving sentence boundaries."""
        try:
            sentences = sent_tokenize(text)
        except Exception as e:
            print(f"Warning: Using fallback sentence splitting: {str(e)}")
            sentences = self._basic_sentence_split(text)
        chunks = []
        current_chunk = []
        current_size = 0
        for sentence in sentences:
            words = sentence.split()
            sentence_size = len(words)
            if current_size + sentence_size > self.pdf_chunk_size:
                if current_chunk:
                    chunks.append(' '.join(current_chunk))
                    # Carry the last `pdf_overlap` words into the next chunk for context
                    overlap_words = current_chunk[-self.pdf_overlap:] if self.pdf_overlap > 0 else []
                    current_chunk = overlap_words + words
                    current_size = len(current_chunk)
                else:
                    current_chunk = words
                    current_size = sentence_size
            else:
                current_chunk.extend(words)
                current_size += sentence_size
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks

    def flatten_xml_to_text(self, element, depth=0) -> str:
        """Convert an XML element and its children to a flat text representation."""
        text_parts = []
        element_info = f"Element: {element.tag}"
        if element.attrib:
            element_info += f", Attributes: {json.dumps(element.attrib)}"
        if element.text and element.text.strip():
            element_info += f", Text: {element.text.strip()}"
        text_parts.append(element_info)
        for child in element:
            child_text = self.flatten_xml_to_text(child, depth + 1)
            text_parts.append(child_text)
        return "\n".join(text_parts)

    def chunk_xml_text(self, text: str, max_chunk_size: int = 2000) -> List[str]:
        """Split flattened XML text into manageable chunks."""
        lines = text.split('\n')
        chunks = []
        current_chunk = []
        current_size = 0
        for line in lines:
            line_size = len(line)
            if current_size + line_size > max_chunk_size and current_chunk:
                chunks.append('\n'.join(current_chunk))
                current_chunk = []
                current_size = 0
            current_chunk.append(line)
            current_size += line_size
        if current_chunk:
            chunks.append('\n'.join(current_chunk))
        return chunks

    def generate_natural_language(self, content: Union[List[Dict], str], content_type: str) -> Optional[str]:
        """Generate a natural-language description, retrying on a shorter prompt if the call fails."""
        try:
            if content_type == "xml":
                prompt = f"Convert this XML structure description to a natural language summary: {content}"
            else:  # pdf
                prompt = f"Summarize this text while preserving key information: {content}"
            max_prompt_length = 4000
            if len(prompt) > max_prompt_length:
                prompt = prompt[:max_prompt_length] + "..."
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192",
                max_tokens=1000
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error generating natural language: {str(e)}")
            # On failure, retry once with the first half of the content
            if len(content) > 2000:
                half_length = len(content) // 2
                first_half = content[:half_length]
                try:
                    return self.generate_natural_language(first_half, content_type)
                except Exception:
                    return None
            return None

    def store_in_vector_db(self, natural_language: str, metadata: Dict) -> str:
        """Store content in the vector database and return its document id."""
        doc_id = f"{metadata['source_file']}_{metadata['content_type']}_{metadata['chunk_id']}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.collection.add(
            documents=[natural_language],
            metadatas=[metadata],
            ids=[doc_id]
        )
        return doc_id

    def process_xml_file(self, xml_file_path: str) -> Dict:
        """Process an XML file: flatten, chunk, summarize, and store each chunk."""
        try:
            tree = ET.parse(xml_file_path)
            root = tree.getroot()
            flattened_text = self.flatten_xml_to_text(root)
            chunks = self.chunk_xml_text(flattened_text)
            print(f"Split XML into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                print(f"Processing XML chunk {i+1}/{len(chunks)}")
                try:
                    natural_language = self.generate_natural_language(chunk, "xml")
                    if natural_language:
                        metadata = {
                            'source_file': os.path.basename(xml_file_path),
                            'content_type': 'xml',
                            'chunk_id': i,
                            'total_chunks': len(chunks),
                            'timestamp': str(datetime.datetime.now())
                        }
                        doc_id = self.store_in_vector_db(natural_language, metadata)
                        results.append({
                            'chunk': i,
                            'success': True,
                            'doc_id': doc_id,
                            'natural_language': natural_language
                        })
                    else:
                        results.append({
                            'chunk': i,
                            'success': False,
                            'error': 'Failed to generate natural language'
                        })
                except Exception as e:
                    print(f"Error processing chunk {i}: {str(e)}")
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': str(e)
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def process_pdf_file(self, pdf_file_path: str) -> Dict:
        """Process a PDF file: extract text, chunk, summarize, and store each chunk."""
        try:
            full_text = self.extract_text_from_pdf(pdf_file_path)
            chunks = self.chunk_text(full_text)
            print(f"Split PDF into {len(chunks)} chunks")
            results = []
            for i, chunk in enumerate(chunks):
                print(f"Processing PDF chunk {i+1}/{len(chunks)}")
                natural_language = self.generate_natural_language(chunk, "pdf")
                if natural_language:
                    metadata = {
                        'source_file': os.path.basename(pdf_file_path),
                        'content_type': 'pdf',
                        'chunk_id': i,
                        'total_chunks': len(chunks),
                        'timestamp': str(datetime.datetime.now()),
                        'chunk_size': len(chunk.split())
                    }
                    doc_id = self.store_in_vector_db(natural_language, metadata)
                    results.append({
                        'chunk': i,
                        'success': True,
                        'doc_id': doc_id,
                        'natural_language': natural_language,
                        'original_text': chunk[:200] + "..."
                    })
                else:
                    results.append({
                        'chunk': i,
                        'success': False,
                        'error': 'Failed to generate natural language summary'
                    })
            return {
                'success': True,
                'total_chunks': len(chunks),
                'results': results
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_available_files(self) -> Dict[str, List[str]]:
        """Get a list of all files stored in the database, grouped by type."""
        try:
            all_entries = self.collection.get(
                include=['metadatas']
            )
            files = {
                'pdf': set(),
                'xml': set()
            }
            for metadata in all_entries['metadatas']:
                file_type = metadata['content_type']
                file_name = metadata['source_file']
                files[file_type].add(file_name)
            return {
                'pdf': sorted(list(files['pdf'])),
                'xml': sorted(list(files['xml']))
            }
        except Exception as e:
            print(f"Error getting available files: {str(e)}")
            return {'pdf': [], 'xml': []}

    def ask_question_selective(self, question: str, selected_files: List[str], n_results: int = 5) -> str:
        """Answer a question using only the selected files."""
        try:
            filter_dict = {
                'source_file': {'$in': selected_files}
            }
            results = self.collection.query(
                query_texts=[question],
                n_results=n_results,
                where=filter_dict,
                include=["documents", "metadatas"]
            )
            if not results['documents'][0]:
                return "No relevant content found in the selected files."
            context = "\n\n".join(results['documents'][0])
            prompt = f"""Based on the following content from the selected files, please answer this question: {question}

Content:
{context}

Please provide a direct answer based only on the information provided above."""
            response = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192",
                temperature=0.2
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Error processing your question: {str(e)}"