# Spaces:
# Running
# Running
import os | |
from pathlib import Path | |
from .bad_query_detector import BadQueryDetector | |
from .query_transformer import QueryTransformer | |
from .document_retriver import DocumentRetriever | |
from .senamtic_response_generator import SemanticResponseGenerator | |
class DocumentSearchSystem:
    """
    Query pipeline chaining malicious-query detection, query transformation,
    document retrieval and response generation.

    Each stage is held as an attribute (``detector``, ``transformer``,
    ``retriever``, ``response_generator``) and may be supplied by the caller;
    stages left unspecified are built with their default constructors.
    """

    def __init__(self, detector=None, transformer=None, retriever=None,
                 response_generator=None):
        """
        Initializes the DocumentSearchSystem with:
        - BadQueryDetector for identifying malicious or inappropriate queries.
        - QueryTransformer for improving or rephrasing queries.
        - DocumentRetriever for semantic document retrieval.
        - SemanticResponseGenerator for generating context-aware responses.

        Any component may be passed in explicitly (e.g. a stub in tests or an
        alternative implementation). A component left as ``None`` is created
        with its class's no-argument constructor, exactly as before.

        :param detector: Object exposing ``is_bad_query(query)``.
        :param transformer: Object exposing ``transform_query(query)``.
        :param retriever: Object exposing ``retrieve(query)``.
        :param response_generator: Object exposing ``generate_response(docs)``.
        """
        # Defaults are constructed lazily (only when not injected) and in the
        # original order, so callers passing no arguments see no change.
        self.detector = detector if detector is not None else BadQueryDetector()
        self.transformer = (
            transformer if transformer is not None else QueryTransformer()
        )
        self.retriever = retriever if retriever is not None else DocumentRetriever()
        self.response_generator = (
            response_generator
            if response_generator is not None
            else SemanticResponseGenerator()
        )

    def process_query(self, query):
        """
        Processes a user query through the following steps:
        0. Reject input that is not a non-empty string.
        1. Reject the query if the detector flags it as malicious.
        2. Pass the query through the transformer. The transformer is always
           called; whether the text actually changes is up to it.
        3. Retrieve documents for the transformed query.
        4. Generate a response from the retrieved documents.

        :param query: The user query as a string.
        :return: A dictionary with a ``"status"`` key:
            - ``"rejected"`` plus ``"message"`` for invalid or malicious input,
            - ``"no_results"`` plus ``"message"`` when nothing was retrieved,
            - ``"success"`` plus ``"response"`` otherwise.
        """
        # Guard: the components' behaviour on None / non-str / blank input is
        # not defined here, so stop before invoking any of them.
        if not isinstance(query, str) or not query.strip():
            return {"status": "rejected", "message": "Query must be a non-empty string."}

        if self.detector.is_bad_query(query):
            return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}

        # Transform the query
        transformed_query = self.transformer.transform_query(query)
        print(f"Transformed Query: {transformed_query}")

        # Retrieve relevant documents
        retrieved_docs = self.retriever.retrieve(transformed_query)
        if not retrieved_docs:
            return {"status": "no_results", "message": "No relevant documents found for your query."}

        # Generate a response based on the retrieved documents
        response = self.response_generator.generate_response(retrieved_docs)
        return {"status": "success", "response": response}
def test_system(data_dir=None):
    """
    Test the DocumentSearchSystem with a normal and a malicious query.

    - Load documents from a dataset directory.
    - Perform a normal query and print the result.
    - Perform a malicious query and print the result, to check it is blocked.

    :param data_dir: Directory handed to ``retriever.load_documents``
        (``str`` or ``Path``). Defaults to ``$HOME/data-sets/aclImdb/train``;
        ``/`` stands in for ``$HOME`` when that variable is unset.
    """
    # Resolve the dataset directory: the caller's choice, or the original default.
    if data_dir is None:
        home_dir = Path(os.getenv("HOME", "/"))
        data_dir = home_dir / "data-sets/aclImdb/train"
    else:
        data_dir = Path(data_dir)

    # Initialize the system
    system = DocumentSearchSystem()
    system.retriever.load_documents(data_dir)

    # Perform a normal query
    normal_query = "Tell me about great acting performances."
    print("\nNormal Query Result:")
    print(system.process_query(normal_query))

    # Perform a malicious query
    malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
    print("\nMalicious Query Result:")
    print(system.process_query(malicious_query))
if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly, never on import.
    test_system()