Spaces:
Running
Running
talexm
commited on
Commit
•
0a4227c
1
Parent(s):
e512ea0
update
Browse files- app.py +26 -10
- rag_sec/document_search_system.py +42 -22
- rag_sec/requirements.txt +1 -0
app.py
CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
|
|
4 |
from PIL import Image
|
5 |
from rag_sec.document_search_system import DocumentSearchSystem
|
6 |
from chainguard.blockchain_logger import BlockchainLogger
|
|
|
7 |
|
8 |
# Blockchain Logger
|
9 |
blockchain_logger = BlockchainLogger()
|
@@ -64,14 +65,29 @@ if st.button("Validate Blockchain Integrity"):
|
|
64 |
|
65 |
# Query System
|
66 |
st.subheader("Query Files")
|
67 |
-
|
|
|
|
|
|
|
68 |
if st.button("Search"):
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
st.write("
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from PIL import Image
|
5 |
from rag_sec.document_search_system import DocumentSearchSystem
|
6 |
from chainguard.blockchain_logger import BlockchainLogger
|
7 |
+
from rag_sec.document_search_system import main
|
8 |
|
9 |
# Blockchain Logger
|
10 |
blockchain_logger = BlockchainLogger()
|
|
|
65 |
|
66 |
# Query System
|
67 |
st.subheader("Query Files")
|
68 |
+
system = main() # Initialize system with Neo4j and load documents
|
69 |
+
|
70 |
+
# Query Input
|
71 |
+
query = st.text_input("Enter your query", placeholder="E.g., 'Good comedy'")
|
72 |
if st.button("Search"):
|
73 |
+
if query:
|
74 |
+
# Process the query
|
75 |
+
result = system.process_query(query)
|
76 |
+
|
77 |
+
# Display the results
|
78 |
+
st.write("Query Status:", result.get("status"))
|
79 |
+
st.write("Query Response:", result.get("response"))
|
80 |
+
|
81 |
+
if "retrieved_documents" in result:
|
82 |
+
st.write("Retrieved Documents:")
|
83 |
+
for doc in result["retrieved_documents"]:
|
84 |
+
st.markdown(f"- {doc}")
|
85 |
+
|
86 |
+
if "blockchain_details" in result:
|
87 |
+
st.write("Blockchain Details:")
|
88 |
+
st.json(result["blockchain_details"])
|
89 |
+
|
90 |
+
if result.get("status") == "rejected":
|
91 |
+
st.error(f"Query Blocked: {result.get('message')}")
|
92 |
+
else:
|
93 |
+
st.warning("Please enter a query to search.")
|
rag_sec/document_search_system.py
CHANGED
@@ -7,10 +7,10 @@ import sys
|
|
7 |
from os import path
|
8 |
|
9 |
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
10 |
-
from
|
11 |
-
from
|
12 |
-
from
|
13 |
-
from
|
14 |
|
15 |
|
16 |
class DataTransformer:
|
@@ -171,12 +171,11 @@ class DocumentSearchSystem:
|
|
171 |
return self.data_transformer.validate_blockchain()
|
172 |
|
173 |
|
174 |
-
|
175 |
-
|
176 |
home_dir = Path(os.getenv("HOME", "/"))
|
177 |
data_dir = home_dir / "data-sets/aclImdb/train"
|
178 |
|
179 |
-
|
180 |
# Initialize system with Neo4j credentials
|
181 |
system = DocumentSearchSystem(
|
182 |
neo4j_uri="neo4j+s://0ca71b10.databases.neo4j.io",
|
@@ -184,21 +183,42 @@ if __name__ == "__main__":
|
|
184 |
neo4j_password="HwGDOxyGS1-79nLeTiX5bx5ohoFSpvHCmTv8IRgt-lY"
|
185 |
)
|
186 |
|
|
|
187 |
system.retriever.load_documents(data_dir)
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
#
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
|
|
|
7 |
from os import path
|
8 |
|
9 |
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
10 |
+
from bad_query_detector import BadQueryDetector
|
11 |
+
from query_transformer import QueryTransformer
|
12 |
+
from document_retriver import DocumentRetriever
|
13 |
+
from senamtic_response_generator import SemanticResponseGenerator
|
14 |
|
15 |
|
16 |
class DataTransformer:
|
|
|
171 |
return self.data_transformer.validate_blockchain()
|
172 |
|
173 |
|
174 |
+
def main():
|
175 |
+
# Path to the dataset directory
|
176 |
home_dir = Path(os.getenv("HOME", "/"))
|
177 |
data_dir = home_dir / "data-sets/aclImdb/train"
|
178 |
|
|
|
179 |
# Initialize system with Neo4j credentials
|
180 |
system = DocumentSearchSystem(
|
181 |
neo4j_uri="neo4j+s://0ca71b10.databases.neo4j.io",
|
|
|
183 |
neo4j_password="HwGDOxyGS1-79nLeTiX5bx5ohoFSpvHCmTv8IRgt-lY"
|
184 |
)
|
185 |
|
186 |
+
# Load documents into the retriever
|
187 |
system.retriever.load_documents(data_dir)
|
188 |
+
print("Documents successfully loaded.")
|
189 |
+
|
190 |
+
return system
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == "__main__":
|
194 |
+
main()
|
195 |
+
|
196 |
+
# home_dir = Path(os.getenv("HOME", "/"))
|
197 |
+
# data_dir = home_dir / "data-sets/aclImdb/train"
|
198 |
+
#
|
199 |
+
#
|
200 |
+
# # Initialize system with Neo4j credentials
|
201 |
+
# system = DocumentSearchSystem(
|
202 |
+
# neo4j_uri="neo4j+s://0ca71b10.databases.neo4j.io",
|
203 |
+
# neo4j_user="neo4j",
|
204 |
+
# neo4j_password="HwGDOxyGS1-79nLeTiX5bx5ohoFSpvHCmTv8IRgt-lY"
|
205 |
+
# )
|
206 |
+
#
|
207 |
+
# system.retriever.load_documents(data_dir)
|
208 |
+
# # Perform a normal query
|
209 |
+
# normal_query = "Good comedy ."
|
210 |
+
# print("\nNormal Query Result:")
|
211 |
+
# result = system.process_query(normal_query)
|
212 |
+
# print("Status:", result["status"])
|
213 |
+
# print("Response:", result["response"])
|
214 |
+
# print("Retrieved Documents:", result["retrieved_documents"])
|
215 |
+
# print("Blockchain Details:", result["blockchain_details"])
|
216 |
+
#
|
217 |
+
# # Perform a malicious query
|
218 |
+
# malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
|
219 |
+
# print("\nMalicious Query Result:")
|
220 |
+
# result = system.process_query(malicious_query)
|
221 |
+
# print("Status:", result["status"])
|
222 |
+
# print("Message:", result.get("message"))
|
223 |
|
224 |
|
rag_sec/requirements.txt
CHANGED
@@ -4,3 +4,4 @@ numpy
|
|
4 |
scikit-learn
|
5 |
faiss-cpu
|
6 |
pandas
|
|
|
|
4 |
scikit-learn
|
5 |
faiss-cpu
|
6 |
pandas
|
7 |
+
transformers
|