# rag/rag_pipeline.py
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import chromadb
from dotenv import load_dotenv
from llama_index.core import Document, PromptTemplate, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()


class RAGPipeline:
    """Builds a vector index over a study JSON file and answers queries against it."""

    def __init__(
        self,
        study_json,
        collection_name="study_files_rag_collection",
        use_semantic_splitter=False,
    ):
        self.study_json = study_json
        self.collection_name = collection_name
        self.use_semantic_splitter = use_semantic_splitter
        self.documents = None
        self.client = chromadb.Client()
        self.collection = self.client.get_or_create_collection(self.collection_name)
        self.embedding_model = OpenAIEmbedding(
            model_name="text-embedding-ada-002", api_key=os.getenv("OPENAI_API_KEY")
        )
        self.is_pdf = self._check_if_pdf_collection()
        self.load_documents()
        self.build_index()

    def _check_if_pdf_collection(self) -> bool:
        """Check if this is a PDF collection based on the JSON structure."""
        try:
            with open(self.study_json, "r") as f:
                data = json.load(f)
            # Check the first document for PDF-specific fields
            if data and isinstance(data, list) and len(data) > 0:
                return "pages" in data[0] and "source_file" in data[0]
            return False
        except Exception as e:
            logger.error(f"Error checking collection type: {str(e)}")
            return False

    def extract_page_number_from_query(self, query: str) -> Optional[int]:
        """Extract a page number from the query text, or return None if absent."""
        # Look for patterns like "page 3", "p. 3", "p3", "pg. 3", "pg 3", etc.
        patterns = [
            r"page\s*(\d+)",
            r"p\.\s*(\d+)",
            r"p\s*(\d+)",
            r"pg\.\s*(\d+)",
            r"pg\s*(\d+)",
        ]
        for pattern in patterns:
            match = re.search(pattern, query.lower())
            if match:
                return int(match.group(1))
        return None

    def load_documents(self):
        """Load the study JSON and convert each entry (or page) into a Document."""
        if self.documents is None:
            with open(self.study_json, "r") as f:
                self.data = json.load(f)

            self.documents = []

            if self.is_pdf:
                # Handle PDF documents: one Document per page
                for index, doc_data in enumerate(self.data):
                    pages = doc_data.get("pages", {})
                    for page_num, page_content in pages.items():
                        if isinstance(page_content, dict):
                            content = page_content.get("text", "")
                        else:
                            content = page_content

                        doc_content = (
                            f"Title: {doc_data['title']}\n"
                            f"Page {page_num} Content:\n{content}\n"
                            f"Authors: {', '.join(doc_data['authors'])}\n"
                        )

                        metadata = {
                            "title": doc_data.get("title"),
                            "authors": ", ".join(doc_data.get("authors", [])),
                            "year": doc_data.get("date"),
                            "source_file": doc_data.get("source_file"),
                            "page_number": int(page_num),
                            "total_pages": doc_data.get("page_count"),
                        }

                        self.documents.append(
                            Document(
                                text=doc_content,
                                id_=f"doc_{index}_page_{page_num}",
                                metadata=metadata,
                            )
                        )
            else:
                # Handle Zotero documents: one Document per entry
                for index, doc_data in enumerate(self.data):
                    doc_content = (
                        f"Title: {doc_data.get('title', '')}\n"
                        f"Abstract: {doc_data.get('abstract', '')}\n"
                        f"Authors: {', '.join(doc_data.get('authors', []))}\n"
                    )

                    metadata = {
                        "title": doc_data.get("title"),
                        "authors": ", ".join(doc_data.get("authors", [])),
                        "year": doc_data.get("date"),
                        "doi": doc_data.get("doi"),
                    }

                    self.documents.append(
                        Document(
                            text=doc_content, id_=f"doc_{index}", metadata=metadata
                        )
                    )

    def build_index(self):
        """Split documents into sentence-window nodes and build the vector index."""
        sentence_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=20)

        def _split(text: str) -> List[str]:
            return sentence_splitter.split_text(text)

        node_parser = SentenceWindowNodeParser.from_defaults(
            sentence_splitter=_split,
            window_size=5,
            window_metadata_key="window",
            original_text_metadata_key="original_text",
        )

        # Parse documents into nodes for embedding
        nodes = node_parser.get_nodes_from_documents(self.documents)

        # Initialize ChromaVectorStore with the existing collection
        vector_store = ChromaVectorStore(chroma_collection=self.collection)

        # Create the VectorStoreIndex using the ChromaVectorStore
        self.index = VectorStoreIndex(
            nodes, vector_store=vector_store, embed_model=self.embedding_model
        )

    def query(
        self, context: str, prompt_template: Optional[PromptTemplate] = None
    ) -> Tuple[str, List[Any]]:
        """Answer a query against the index; return the response text and source nodes."""
        if prompt_template is None:
            prompt_template = PromptTemplate(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given this information, please answer the question: {query_str}\n"
                "Follow these guidelines for your response:\n"
                "1. If the answer contains multiple pieces of information (e.g., author names, dates, statistics), "
                "present it in a markdown table format.\n"
                "2. For single piece information or simple answers, respond in a clear sentence.\n"
                "3. Always cite sources using square brackets for EVERY piece of information, e.g. [1], [2], etc.\n"
                "4. If the information spans multiple documents or pages, organize it by source.\n"
                "5. If you're unsure about something, say so rather than making assumptions.\n"
                "\nFormat tables like this:\n"
                "| Field | Information | Source |\n"
                "|-------|-------------|--------|\n"
                "| Title | Example Title | [1] |\n"
            )

        # Extract page number for PDF documents
        # (currently informational only; it is not used to filter retrieved nodes)
        requested_page = (
            self.extract_page_number_from_query(context) if self.is_pdf else None
        )

        n_documents = len(self.index.docstore.docs)
        logger.info(f"n_documents: {n_documents}")

        query_engine = self.index.as_query_engine(
            text_qa_template=prompt_template,
            similarity_top_k=n_documents if n_documents <= 17 else 15,
            response_mode="tree_summarize",
            llm=OpenAI(model="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY")),
        )

        response = query_engine.query(context)

        # Debug logging
        logger.info(f"Response type: {type(response)}")
        logger.info(f"Has source_nodes: {hasattr(response, 'source_nodes')}")
        if hasattr(response, "source_nodes"):
            logger.info(f"Number of source nodes: {len(response.source_nodes)}")

        return response.response, getattr(response, "source_nodes", [])
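

# Minimal usage sketch, not part of the pipeline itself. Assumptions: a study JSON
# at the hypothetical path "data/study_files.json" and OPENAI_API_KEY available in
# the environment (or a .env file picked up by load_dotenv()).
if __name__ == "__main__":
    pipeline = RAGPipeline(study_json="data/study_files.json")
    answer, source_nodes = pipeline.query(
        "What are the titles and authors of the included studies?"
    )
    print(answer)
    for node_with_score in source_nodes:
        # Each source node carries the metadata attached in load_documents()
        print(node_with_score.node.metadata.get("title"), node_with_score.score)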