talexm committed on
Commit
0a4227c
1 Parent(s): e512ea0
app.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
4
  from PIL import Image
5
  from rag_sec.document_search_system import DocumentSearchSystem
6
  from chainguard.blockchain_logger import BlockchainLogger
 
7
 
8
  # Blockchain Logger
9
  blockchain_logger = BlockchainLogger()
@@ -64,14 +65,29 @@ if st.button("Validate Blockchain Integrity"):
64
 
65
  # Query System
66
  st.subheader("Query Files")
67
- query = st.text_input("Enter your query (e.g., 'Good comedy')")
 
 
 
68
  if st.button("Search"):
69
- result = system.process_query(query)
70
- st.write("Query Status:", result.get("status"))
71
- st.write("Query Response:", result.get("response"))
72
- if "retrieved_documents" in result:
73
- st.write("Retrieved Documents:", result["retrieved_documents"])
74
- if "blockchain_details" in result:
75
- st.write("Blockchain Details:", result["blockchain_details"])
76
- if result.get("status") == "rejected":
77
- st.error(f"Query Blocked: {result.get('message')}")
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from PIL import Image
5
  from rag_sec.document_search_system import DocumentSearchSystem
6
  from chainguard.blockchain_logger import BlockchainLogger
7
+ from rag_sec.document_search_system import main
8
 
9
  # Blockchain Logger
10
  blockchain_logger = BlockchainLogger()
 
65
 
66
  # Query System
67
  st.subheader("Query Files")
68
+ system = main() # Initialize system with Neo4j and load documents
69
+
70
+ # Query Input
71
+ query = st.text_input("Enter your query", placeholder="E.g., 'Good comedy'")
72
  if st.button("Search"):
73
+ if query:
74
+ # Process the query
75
+ result = system.process_query(query)
76
+
77
+ # Display the results
78
+ st.write("Query Status:", result.get("status"))
79
+ st.write("Query Response:", result.get("response"))
80
+
81
+ if "retrieved_documents" in result:
82
+ st.write("Retrieved Documents:")
83
+ for doc in result["retrieved_documents"]:
84
+ st.markdown(f"- {doc}")
85
+
86
+ if "blockchain_details" in result:
87
+ st.write("Blockchain Details:")
88
+ st.json(result["blockchain_details"])
89
+
90
+ if result.get("status") == "rejected":
91
+ st.error(f"Query Blocked: {result.get('message')}")
92
+ else:
93
+ st.warning("Please enter a query to search.")
rag_sec/document_search_system.py CHANGED
@@ -7,10 +7,10 @@ import sys
7
  from os import path
8
 
9
  sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
10
- from .bad_query_detector import BadQueryDetector
11
- from .query_transformer import QueryTransformer
12
- from .document_retriver import DocumentRetriever
13
- from .senamtic_response_generator import SemanticResponseGenerator
14
 
15
 
16
  class DataTransformer:
@@ -171,12 +171,11 @@ class DocumentSearchSystem:
171
  return self.data_transformer.validate_blockchain()
172
 
173
 
174
- if __name__ == "__main__":
175
-
176
  home_dir = Path(os.getenv("HOME", "/"))
177
  data_dir = home_dir / "data-sets/aclImdb/train"
178
 
179
-
180
  # Initialize system with Neo4j credentials
181
  system = DocumentSearchSystem(
182
  neo4j_uri="neo4j+s://0ca71b10.databases.neo4j.io",
@@ -184,21 +183,42 @@ if __name__ == "__main__":
184
  neo4j_password="HwGDOxyGS1-79nLeTiX5bx5ohoFSpvHCmTv8IRgt-lY"
185
  )
186
 
 
187
  system.retriever.load_documents(data_dir)
188
- # Perform a normal query
189
- normal_query = "Good comedy ."
190
- print("\nNormal Query Result:")
191
- result = system.process_query(normal_query)
192
- print("Status:", result["status"])
193
- print("Response:", result["response"])
194
- print("Retrieved Documents:", result["retrieved_documents"])
195
- print("Blockchain Details:", result["blockchain_details"])
196
-
197
- # Perform a malicious query
198
- malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
199
- print("\nMalicious Query Result:")
200
- result = system.process_query(malicious_query)
201
- print("Status:", result["status"])
202
- print("Message:", result.get("message"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
 
 
7
  from os import path
8
 
9
  sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
10
+ from bad_query_detector import BadQueryDetector
11
+ from query_transformer import QueryTransformer
12
+ from document_retriver import DocumentRetriever
13
+ from senamtic_response_generator import SemanticResponseGenerator
14
 
15
 
16
  class DataTransformer:
 
171
  return self.data_transformer.validate_blockchain()
172
 
173
 
174
+ def main():
175
+ # Path to the dataset directory
176
  home_dir = Path(os.getenv("HOME", "/"))
177
  data_dir = home_dir / "data-sets/aclImdb/train"
178
 
 
179
  # Initialize system with Neo4j credentials
180
  system = DocumentSearchSystem(
181
  neo4j_uri="neo4j+s://0ca71b10.databases.neo4j.io",
 
183
  neo4j_password="HwGDOxyGS1-79nLeTiX5bx5ohoFSpvHCmTv8IRgt-lY"
184
  )
185
 
186
+ # Load documents into the retriever
187
  system.retriever.load_documents(data_dir)
188
+ print("Documents successfully loaded.")
189
+
190
+ return system
191
+
192
+
193
+ if __name__ == "__main__":
194
+ main()
195
+
196
+ # home_dir = Path(os.getenv("HOME", "/"))
197
+ # data_dir = home_dir / "data-sets/aclImdb/train"
198
+ #
199
+ #
200
+ # # Initialize system with Neo4j credentials
201
+ # system = DocumentSearchSystem(
202
+ # neo4j_uri="neo4j+s://0ca71b10.databases.neo4j.io",
203
+ # neo4j_user="neo4j",
204
+ # neo4j_password="HwGDOxyGS1-79nLeTiX5bx5ohoFSpvHCmTv8IRgt-lY"
205
+ # )
206
+ #
207
+ # system.retriever.load_documents(data_dir)
208
+ # # Perform a normal query
209
+ # normal_query = "Good comedy ."
210
+ # print("\nNormal Query Result:")
211
+ # result = system.process_query(normal_query)
212
+ # print("Status:", result["status"])
213
+ # print("Response:", result["response"])
214
+ # print("Retrieved Documents:", result["retrieved_documents"])
215
+ # print("Blockchain Details:", result["blockchain_details"])
216
+ #
217
+ # # Perform a malicious query
218
+ # malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
219
+ # print("\nMalicious Query Result:")
220
+ # result = system.process_query(malicious_query)
221
+ # print("Status:", result["status"])
222
+ # print("Message:", result.get("message"))
223
 
224
 
rag_sec/requirements.txt CHANGED
@@ -4,3 +4,4 @@ numpy
4
  scikit-learn
5
  faiss-cpu
6
  pandas
 
 
4
  scikit-learn
5
  faiss-cpu
6
  pandas
7
+ transformers