# Captured from a Spaces page (status shown: Running); the original header lines were page-UI residue, not code.
import os | |
from pathlib import Path | |
from chainguard.blockchain_logger import BlockchainLogger | |
from neo4j import GraphDatabase | |
import sys | |
from os import path | |
sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) | |
from .bad_query_detector import BadQueryDetector | |
from .query_transformer import QueryTransformer | |
from .document_retriver import DocumentRetriever | |
from .senamtic_response_generator import SemanticResponseGenerator | |
class DataTransformer:
    """Record arbitrary payloads on a ``BlockchainLogger`` and expose its integrity check."""

    def __init__(self):
        """Attach a freshly constructed ``BlockchainLogger`` to this transformer."""
        self.blockchain_logger = BlockchainLogger()

    def secure_transform(self, data):
        """Append ``data`` to the blockchain log and report the resulting block.

        Args:
            data: Payload to record. Callers in this module pass a dict with
                ``type`` and ``content`` keys.

        Returns:
            dict: ``{"data": data}`` merged with the mapping returned by
            ``BlockchainLogger.log_data`` (presumably block hash / chain
            length — depends on the logger). On a key collision the logger's
            value wins, because it is unpacked after ``"data"``.
        """
        receipt = self.blockchain_logger.log_data(data)
        return {"data": data, **receipt}

    def validate_blockchain(self):
        """Return the logger's ``is_blockchain_valid()`` verdict (expected to be a bool)."""
        chain = self.blockchain_logger
        return chain.is_blockchain_valid()
class Neo4jHandler:
    """Persist query / response / document lineage into a Neo4j graph."""

    def __init__(self, uri, user, password):
        """
        Open a Neo4j driver for ``uri`` authenticated as ``user``/``password``.

        Call :meth:`close` (or use the handler in a ``with`` statement) to
        release the driver's connections.
        """
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def __enter__(self):
        """Return self so the handler can be used as a context manager."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the driver on leaving the ``with`` block; exceptions propagate."""
        self.close()
        return False

    def log_relationships(self, query, transformed_query, response, documents):
        """
        Record one query's lineage in a single write transaction.

        MERGEs ``Query``, ``TransformedQuery``, ``Response`` and ``Document``
        nodes (keyed on ``text`` / ``name``) and links them with
        ``TRANSFORMED_TO``, ``GENERATED`` and ``RETRIEVED`` relationships.

        :param query: Original user query text.
        :param transformed_query: Rewritten query text.
        :param response: Generated response text.
        :param documents: Iterable of document identifiers, each stored as a
            ``Document.name`` property. Items must be Neo4j property values
            (e.g. strings) — NOTE(review): confirm what the retriever returns.
        """
        # Managed transactions may be retried, so an iterator argument must be
        # materialised once up front rather than consumed inside the function.
        docs = list(documents)
        with self.driver.session() as session:
            # neo4j 5.x renamed write_transaction -> execute_write and
            # deprecated the old name; fall back for older drivers.
            run_write = getattr(session, "execute_write", None)
            if run_write is None:
                run_write = session.write_transaction
            run_write(self._create_and_link_nodes, query, transformed_query, response, docs)

    @staticmethod
    def _create_and_link_nodes(tx, query, transformed_query, response, documents):
        """
        Transaction function: write the lineage graph using ``tx``.

        Must be a staticmethod: the driver calls it as ``fn(tx, *args)``.
        Passing a bound method would shift every argument by one.
        """
        # Nodes and the query-level relationships in one round trip.
        tx.run(
            """
            MERGE (q:Query {text: $query})
            MERGE (t:TransformedQuery {text: $transformed_query})
            MERGE (r:Response {text: $response})
            MERGE (q)-[:TRANSFORMED_TO]->(t)
            MERGE (q)-[:GENERATED]->(r)
            """,
            parameters={
                "query": query,
                "transformed_query": transformed_query,
                "response": response,
            },
        )
        docs = list(documents)
        if not docs:
            return
        # One statement for all retrieved documents instead of two per document.
        tx.run(
            """
            MATCH (q:Query {text: $query})
            UNWIND $docs AS doc
            MERGE (d:Document {name: doc})
            MERGE (q)-[:RETRIEVED]->(d)
            """,
            parameters={"query": query, "docs": docs},
        )
class DocumentSearchSystem:
    """
    End-to-end query pipeline with blockchain and Neo4j audit logging.

    The instance owns a Neo4j driver (through ``Neo4jHandler``). Call
    :meth:`close`, or use the instance as a context manager, so the driver is
    released.
    """

    def __init__(self, neo4j_uri, neo4j_user, neo4j_password):
        """
        Initializes the DocumentSearchSystem with:
        - BadQueryDetector for screening incoming queries.
        - QueryTransformer for rewriting queries before retrieval.
        - DocumentRetriever for document retrieval.
        - SemanticResponseGenerator for building a response from retrieved documents.
        - DataTransformer for blockchain logging of queries, documents and responses.
        - Neo4jHandler for recording query/response/document relationships.

        :param neo4j_uri: Neo4j connection URI.
        :param neo4j_user: Neo4j user name.
        :param neo4j_password: Neo4j password.
        """
        self.detector = BadQueryDetector()
        self.transformer = QueryTransformer()
        self.retriever = DocumentRetriever()
        self.response_generator = SemanticResponseGenerator()
        self.data_transformer = DataTransformer()
        self.neo4j_handler = Neo4jHandler(neo4j_uri, neo4j_user, neo4j_password)

    def close(self):
        """Release the Neo4j driver; the other components are left untouched."""
        self.neo4j_handler.close()

    def __enter__(self):
        """Return self so the system can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the Neo4j driver on leaving the ``with`` block; exceptions propagate."""
        self.close()
        return False

    def process_query(self, query):
        """
        Processes a user query through the following steps:
        1. Reject the query if the detector flags it.
        2. Rewrite it with the QueryTransformer (always applied).
        3. Log the original query to the blockchain.
        4. Retrieve documents for the rewritten query.
        5. Log the documents, generate a response, and log the response.
        6. Record the query/response/document relationships in Neo4j.

        :param query: The user query as a string.
        :return: ``{"status": "rejected" | "no_results", "message": str}`` on
            early exit (note that "no_results" happens after the query has
            already been logged to the blockchain), otherwise
            ``{"status": "success", "response", "retrieved_documents",
            "blockchain_details"}`` where ``blockchain_details`` is the record
            returned when the response was logged.
        """
        if self.detector.is_bad_query(query):
            return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}

        transformed_query = self.transformer.transform_query(query)
        self.data_transformer.secure_transform({"type": "query", "content": query})

        # Retrieval runs on the rewritten query, not the raw input.
        retrieved_docs = self.retriever.retrieve(transformed_query)
        if not retrieved_docs:
            return {"status": "no_results", "message": "No relevant documents found for your query."}

        self.data_transformer.secure_transform({"type": "documents", "content": retrieved_docs})
        response = self.response_generator.generate_response(retrieved_docs)
        blockchain_details = self.data_transformer.secure_transform({"type": "response", "content": response})

        # Not atomic with the blockchain writes above: a Neo4j failure raises
        # here after the blockchain entries already exist.
        self.neo4j_handler.log_relationships(query, transformed_query, response, retrieved_docs)

        return {
            "status": "success",
            "response": response,
            "retrieved_documents": retrieved_docs,
            "blockchain_details": blockchain_details,
        }

    def validate_system_integrity(self):
        """
        Validates the integrity of the blockchain.

        :return: The result of ``DataTransformer.validate_blockchain()``.
        """
        return self.data_transformer.validate_blockchain()
if __name__ == "__main__":
    home_dir = Path(os.getenv("HOME", "/"))
    data_dir = home_dir / "data-sets/aclImdb/train"

    # Credentials come from the environment so no secret lives in source.
    # NOTE(review): the password previously hard-coded here has been exposed
    # in version control and should be rotated on the Neo4j side.
    neo4j_password = os.getenv("NEO4J_PASSWORD")
    if not neo4j_password:
        sys.exit("Set NEO4J_PASSWORD (and optionally NEO4J_URI / NEO4J_USER) before running.")

    system = DocumentSearchSystem(
        neo4j_uri=os.getenv("NEO4J_URI", "neo4j+s://0ca71b10.databases.neo4j.io"),
        neo4j_user=os.getenv("NEO4J_USER", "neo4j"),
        neo4j_password=neo4j_password,
    )
    try:
        system.retriever.load_documents(data_dir)

        # Perform a normal query
        normal_query = "Good comedy ."
        print("\nNormal Query Result:")
        result = system.process_query(normal_query)
        print("Status:", result["status"])
        # Only the "success" result carries these keys; the other statuses
        # carry a "message" instead.
        if result["status"] == "success":
            print("Response:", result["response"])
            print("Retrieved Documents:", result["retrieved_documents"])
            print("Blockchain Details:", result["blockchain_details"])
        else:
            print("Message:", result.get("message"))

        # Perform a malicious query
        malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
        print("\nMalicious Query Result:")
        result = system.process_query(malicious_query)
        print("Status:", result["status"])
        print("Message:", result.get("message"))
    finally:
        # Release the Neo4j driver even if loading or querying fails.
        system.neo4j_handler.close()