# chagu-demo / rag_sec/document_search_system.py
# (source-page residue kept for provenance: talexm · update · c3c1187 · raw · history blame · 2.99 kB)
import logging
import os
from pathlib import Path

from .bad_query_detector import BadQueryDetector
from .query_transformer import QueryTransformer
from .document_retriver import DocumentRetriever
from .senamtic_response_generator import SemanticResponseGenerator
class DocumentSearchSystem:
    """
    Query pipeline that screens a query, rewrites it, retrieves documents
    for it, and builds a response from those documents.

    Attributes:
        detector: BadQueryDetector; ``is_bad_query(query)`` decides rejection.
        transformer: QueryTransformer; ``transform_query(query)`` rewrites the
            query before retrieval.
        retriever: DocumentRetriever; ``retrieve(query)`` returns the documents
            for a query (a falsy result is treated as "no results").
        response_generator: SemanticResponseGenerator;
            ``generate_response(docs)`` builds the final response.
    """

    # Diagnostics go through logging instead of print() so that callers of
    # this library class control verbosity and stdout is not polluted.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """
        Initializes the DocumentSearchSystem with:
        - BadQueryDetector for identifying malicious or inappropriate queries.
        - QueryTransformer for improving or rephrasing queries.
        - DocumentRetriever for semantic document retrieval.
        - SemanticResponseGenerator for generating context-aware responses.
        """
        self.detector = BadQueryDetector()
        self.transformer = QueryTransformer()
        self.retriever = DocumentRetriever()
        self.response_generator = SemanticResponseGenerator()

    def process_query(self, query) -> dict:
        """
        Run a user query through the pipeline.

        Steps:
        1. Reject the query if the detector flags it.
        2. Transform the query.
        3. Retrieve documents for the transformed query.
        4. Generate a response from the retrieved documents.

        :param query: The user query as a string.
        :return: One of
            ``{"status": "rejected", "message": ...}`` when the detector flags
            the query (no transformation or retrieval is performed);
            ``{"status": "no_results", "message": ...}`` when retrieval returns
            a falsy value (e.g. ``None`` or an empty list);
            ``{"status": "success", "response": ...}`` otherwise.
        """
        # Screen first so flagged input never reaches the downstream components.
        if self.detector.is_bad_query(query):
            return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}

        transformed_query = self.transformer.transform_query(query)
        # Lazy %-formatting: the string is only built if DEBUG is enabled.
        self._logger.debug("Transformed Query: %s", transformed_query)

        retrieved_docs = self.retriever.retrieve(transformed_query)
        if not retrieved_docs:
            return {"status": "no_results", "message": "No relevant documents found for your query."}

        response = self.response_generator.generate_response(retrieved_docs)
        return {"status": "success", "response": response}
def test_system(data_dir=None):
    """
    Demo run of DocumentSearchSystem with a normal and a malicious query.

    - Load documents from a dataset directory.
    - Perform a normal query and print the result.
    - Perform a malicious query and print the result, to show it is blocked.

    Despite its name, this is a demo that needs the dataset on disk, not a
    self-contained unit test.

    :param data_dir: Directory handed to ``system.retriever.load_documents``.
        Defaults to ``<home>/data-sets/aclImdb/train``. The home directory is
        resolved with ``Path.home()``, which also works where ``HOME`` is
        unset (e.g. Windows).
    """
    if data_dir is None:
        data_dir = Path.home() / "data-sets" / "aclImdb" / "train"
    # Accept str paths too; the original always passed a Path.
    data_dir = Path(data_dir)

    system = DocumentSearchSystem()
    system.retriever.load_documents(data_dir)

    normal_query = "Tell me about great acting performances."
    print("\nNormal Query Result:")
    print(system.process_query(normal_query))

    # SQL-injection-shaped input; expected to come back with status "rejected".
    malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
    print("\nMalicious Query Result:")
    print(system.process_query(malicious_query))
if __name__ == "__main__":
    # Executed directly rather than imported: run the end-to-end demo.
    test_system()