chagu-demo / rag_sec /document_search_system.py
talexm
update
9c63cfb
raw
history blame
7.67 kB
import os
from pathlib import Path
from chainguard.blockchain_logger import BlockchainLogger
from neo4j import GraphDatabase
import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from .bad_query_detector import BadQueryDetector
from .query_transformer import QueryTransformer
from .document_retriver import DocumentRetriever
from .senamtic_response_generator import SemanticResponseGenerator
class DataTransformer:
    """
    Records arbitrary payloads through a ``BlockchainLogger`` and exposes the
    logger's integrity check.
    """

    def __init__(self):
        """
        Build the transformer together with the ``BlockchainLogger`` it writes to.
        """
        self.blockchain_logger = BlockchainLogger()

    def secure_transform(self, data):
        """
        Append ``data`` to the blockchain log and report where it landed.

        Args:
            data (dict): Payload to record (a log entry or any other data).

        Returns:
            dict: ``{"data": data}`` merged with the mapping returned by
            ``BlockchainLogger.log_data``; on a key collision the logger's
            value wins. The exact keys (presumably block hash / chain length)
            depend on ``BlockchainLogger`` — confirm there.
        """
        receipt = self.blockchain_logger.log_data(data)
        return {"data": data, **receipt}

    def validate_blockchain(self):
        """
        Ask the underlying logger whether its chain is intact.

        Returns:
            bool: Whatever ``BlockchainLogger.is_blockchain_valid`` reports
            (True when valid, False otherwise).
        """
        logger = self.blockchain_logger
        return logger.is_blockchain_valid()
class Neo4jHandler:
    """
    Wrapper around a Neo4j driver that records how a query, its rewritten
    form, the generated response and the retrieved documents relate.

    Usable as a context manager so the driver is always closed::

        with Neo4jHandler(uri, user, password) as handler:
            handler.log_relationships(...)
    """

    # One round trip instead of 5 + 2*len(documents): MERGE the three text
    # nodes, MERGE their relationships directly (no re-MATCH needed), then
    # UNWIND the document list. An empty $documents list yields zero UNWIND
    # rows, which only skips the document part; the MERGEs before it have
    # already been applied.
    _LINK_QUERY = """
    MERGE (q:Query {text: $query})
    MERGE (t:TransformedQuery {text: $transformed_query})
    MERGE (r:Response {text: $response})
    MERGE (q)-[:TRANSFORMED_TO]->(t)
    MERGE (q)-[:GENERATED]->(r)
    WITH DISTINCT q
    UNWIND $documents AS doc
    MERGE (d:Document {name: doc})
    MERGE (q)-[:RETRIEVED]->(d)
    """

    def __init__(self, uri, user, password):
        """
        Open a Neo4j driver for storing and querying relationships.

        Args:
            uri: Neo4j connection URI (e.g. ``neo4j+s://...``).
            user: Database user name.
            password: Database password.
        """
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Close the underlying driver and release its connections."""
        self.driver.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def log_relationships(self, query, transformed_query, response, documents):
        """
        Record Query/TransformedQuery/Response/Document nodes and their links
        in a single write transaction.

        Args:
            query: Original user query text.
            transformed_query: Rewritten query text.
            response: Generated response text.
            documents: Iterable of document identifiers; each becomes a
                ``Document`` node's ``name`` (assumed to be a Neo4j-storable
                scalar such as a string — TODO confirm against the retriever).
        """
        with self.driver.session() as session:
            # neo4j 5.x renamed write_transaction -> execute_write and
            # deprecated the old name; fall back for 4.x drivers.
            run_write = getattr(session, "execute_write", None) or session.write_transaction
            run_write(self._create_and_link_nodes, query, transformed_query, response, documents)

    @staticmethod
    def _create_and_link_nodes(tx, query, transformed_query, response, documents):
        """
        Transaction function: MERGE all nodes and relationships via one
        parameterised Cypher statement.
        """
        tx.run(
            Neo4jHandler._LINK_QUERY,
            parameters={
                "query": query,
                "transformed_query": transformed_query,
                "response": response,
                # Materialise so generators/other iterables are sent as a Cypher list.
                "documents": list(documents),
            },
        )
class DocumentSearchSystem:
    """
    End-to-end query pipeline: screening, rewriting, retrieval, response
    generation, and audit logging to the blockchain and Neo4j.

    The instance owns a Neo4j driver (through ``Neo4jHandler``); call
    :meth:`close` when finished, or use the instance as a context manager.
    """

    def __init__(self, neo4j_uri, neo4j_user, neo4j_password):
        """
        Initializes the DocumentSearchSystem with:
        - BadQueryDetector for identifying malicious or inappropriate queries.
        - QueryTransformer for improving or rephrasing queries.
        - DocumentRetriever for semantic document retrieval.
        - SemanticResponseGenerator for generating context-aware responses.
        - DataTransformer for blockchain logging of queries and responses.
        - Neo4jHandler for relationship logging and visualization.

        Args:
            neo4j_uri: Neo4j connection URI.
            neo4j_user: Neo4j user name.
            neo4j_password: Neo4j password.
        """
        self.detector = BadQueryDetector()
        self.transformer = QueryTransformer()
        self.retriever = DocumentRetriever()
        self.response_generator = SemanticResponseGenerator()
        self.data_transformer = DataTransformer()
        self.neo4j_handler = Neo4jHandler(neo4j_uri, neo4j_user, neo4j_password)

    def process_query(self, query):
        """
        Run a user query through the pipeline.

        Steps:
        1. Reject the query if the detector flags it (nothing is logged).
        2. Transform the query.
        3. Log the ORIGINAL query to the blockchain (the transformed query is
           not logged to the blockchain).
        4. Retrieve documents using the transformed query; if none are found,
           return ``no_results`` (only the query has been logged).
        5. Log the retrieved documents, generate a response, log the response.
        6. Record query/transformed query/response/document links in Neo4j.

        :param query: The user query as a string.
        :return: A dict with ``status`` of ``"rejected"`` / ``"no_results"``
            (plus ``message``) or ``"success"`` (plus ``response``,
            ``retrieved_documents`` and ``blockchain_details`` of the
            response block).
        """
        if self.detector.is_bad_query(query):
            return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}

        transformed_query = self.transformer.transform_query(query)
        self.data_transformer.secure_transform({"type": "query", "content": query})

        retrieved_docs = self.retriever.retrieve(transformed_query)
        if not retrieved_docs:
            return {"status": "no_results", "message": "No relevant documents found for your query."}
        self.data_transformer.secure_transform({"type": "documents", "content": retrieved_docs})

        response = self.response_generator.generate_response(retrieved_docs)
        blockchain_details = self.data_transformer.secure_transform({"type": "response", "content": response})

        self.neo4j_handler.log_relationships(query, transformed_query, response, retrieved_docs)
        return {
            "status": "success",
            "response": response,
            "retrieved_documents": retrieved_docs,
            "blockchain_details": blockchain_details,
        }

    def validate_system_integrity(self):
        """
        Validates the integrity of the blockchain.

        :return: The result of ``DataTransformer.validate_blockchain()``.
        """
        return self.data_transformer.validate_blockchain()

    def close(self):
        """Release the Neo4j driver held by this system."""
        self.neo4j_handler.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions
if __name__ == "__main__":
    home_dir = Path(os.getenv("HOME", "/"))
    data_dir = home_dir / "data-sets/aclImdb/train"

    # Credentials come from the environment so no secret lives in source.
    # NOTE(review): a Neo4j password was previously hard-coded here; treat it
    # as leaked and rotate it on the database.
    neo4j_password = os.getenv("NEO4J_PASSWORD")
    if not neo4j_password:
        sys.exit("Set NEO4J_PASSWORD (and optionally NEO4J_URI / NEO4J_USER) before running this demo.")

    system = DocumentSearchSystem(
        neo4j_uri=os.getenv("NEO4J_URI", "neo4j+s://0ca71b10.databases.neo4j.io"),
        neo4j_user=os.getenv("NEO4J_USER", "neo4j"),
        neo4j_password=neo4j_password,
    )
    try:
        system.retriever.load_documents(data_dir)

        # Perform a normal query
        normal_query = "Good comedy ."
        print("\nNormal Query Result:")
        result = system.process_query(normal_query)
        print("Status:", result["status"])
        # Only "success" results carry response/documents/blockchain details;
        # "no_results" and "rejected" carry a message instead.
        if result["status"] == "success":
            print("Response:", result["response"])
            print("Retrieved Documents:", result["retrieved_documents"])
            print("Blockchain Details:", result["blockchain_details"])
        else:
            print("Message:", result.get("message"))

        # Perform a malicious query
        malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
        print("\nMalicious Query Result:")
        result = system.process_query(malicious_query)
        print("Status:", result["status"])
        print("Message:", result.get("message"))
    finally:
        # Always release the Neo4j driver, even if a query fails.
        system.neo4j_handler.close()