talexm committed
Commit e9a8c67
1 Parent(s): f861dee
rag_sec/__pycache__/rag_chagu_demo.cpython-38-pytest-8.3.2.pyc CHANGED
Binary files a/rag_sec/__pycache__/rag_chagu_demo.cpython-38-pytest-8.3.2.pyc and b/rag_sec/__pycache__/rag_chagu_demo.cpython-38-pytest-8.3.2.pyc differ
 
rag_sec/rag_chagu_demo.py CHANGED
@@ -1,104 +1,101 @@
-from transformers import pipeline
-from difflib import get_close_matches
-from pathlib import Path
 import os


-class BadQueryDetector:
     def __init__(self):
-        self.detector = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

-    def is_bad_query(self, query):
-        result = self.detector(query)[0]
-        label = result["label"]
-        score = result["score"]
-        # Mark queries as malicious or bad if negative sentiment with high confidence
-        if label == "NEGATIVE" and score > 0.8:
-            print(f"Detected malicious query with high confidence ({score:.4f}): {query}")
-            return True
-        return False


-class QueryTransformer:
-    def transform_query(self, query):
-        # Simple transformation example: rephrasing and clarifying
-        # In practice, this could involve more sophisticated models like T5
-        if "DROP TABLE" in query or "SELECT *" in query:
-            return "Your query appears to contain SQL injection elements. Please rephrase."
-        # Add more sophisticated handling here
-        return query


-class DocumentRetriever:
-    def __init__(self):
-        self.documents = []

-    def load_documents(self, source_dir):
-        data_dir = Path(source_dir)
-        if not data_dir.exists():
-            print(f"Source directory not found: {source_dir}")
-            return
-
-        for file in data_dir.glob("*.txt"):
-            with open(file, "r", encoding="utf-8") as f:
-                self.documents.append(f.read())
-
-        print(f"Loaded {len(self.documents)} documents.")

-    def retrieve(self, query):
-        matches = get_close_matches(query, self.documents, n=5, cutoff=0.3)
-        return matches if matches else ["No matching documents found."]


-class SemanticResponseGenerator:
-    def __init__(self):
-        self.generator = pipeline("text-generation", model="gpt2")

-    def generate_response(self, retrieved_docs):
-        # Generate a semantic response using retrieved documents
-        combined_docs = " ".join(retrieved_docs[:2])  # Use top 2 matches for response
-        response = self.generator(f"Based on the following information: {combined_docs}", max_length=100)
-        return response[0]["generated_text"]


-class DocumentSearchSystem:
-    def __init__(self):
-        self.detector = BadQueryDetector()
-        self.transformer = QueryTransformer()
-        self.retriever = DocumentRetriever()
-        self.response_generator = SemanticResponseGenerator()

-    def process_query(self, query):
-        if self.detector.is_bad_query(query):
-            return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}

-        transformed_query = self.transformer.transform_query(query)
-        retrieved_docs = self.retriever.retrieve(transformed_query)

-        if "No matching documents found." in retrieved_docs:
-            return {"status": "no_results", "message": "No relevant documents found for your query."}

-        response = self.response_generator.generate_response(retrieved_docs)
-        return {"status": "success", "response": response}


-# Test the enhanced system
-def test_system():
-    system = DocumentSearchSystem()
-    system.retriever.load_documents("/path/to/documents")

-    # Test with a normal query
-    normal_query = "Tell me about great acting performances."
-    normal_result = system.process_query(normal_query)
-    print("\nNormal Query Result:")
-    print(normal_result)

-    # Test with a malicious query
-    malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
-    malicious_result = system.process_query(malicious_query)
-    print("\nMalicious Query Result:")
-    print(malicious_result)


 if __name__ == "__main__":
-    test_system()
 import os
+from pathlib import Path
+from difflib import get_close_matches
+from transformers import pipeline


+class DocumentSearcher:
     def __init__(self):
+        self.documents = []
+        # Load a pre-trained model for malicious intent detection
+        self.malicious_detector = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

+    def load_imdb_data(self):
+        home_dir = Path(os.getenv("HOME", "/"))
+        data_dir = home_dir / "data-sets/aclImdb/train"
+        pos_dir = data_dir / "pos"
+        neg_dir = data_dir / "neg"

+        print(f"Looking for positive reviews in: {pos_dir}")
+        print(f"Looking for negative reviews in: {neg_dir}")

+        if not pos_dir.exists() or not any(pos_dir.iterdir()):
+            print("No positive reviews found.")
+        if not neg_dir.exists() or not any(neg_dir.iterdir()):
+            print("No negative reviews found.")

+        for filename in pos_dir.iterdir():
+            with open(filename, "r", encoding="utf-8") as file:
+                self.documents.append(file.read())

+        for filename in neg_dir.iterdir():
+            with open(filename, "r", encoding="utf-8") as file:
+                self.documents.append(file.read())

+        print(f"Loaded {len(self.documents)} movie reviews from IMDB dataset.")

+    def load_txt_files(self, txt_dir=None):
+        if txt_dir is None:
+            home_dir = Path(os.getenv("HOME", "/"))
+            txt_dir = home_dir / "data-sets/txt-files/"

+        if not txt_dir.exists():
+            print("No .txt files directory found.")
+            return

+        for filename in txt_dir.glob("*.txt"):
+            with open(filename, "r", encoding="utf-8") as file:
+                self.documents.append(file.read())

+        print(f"Loaded additional {len(self.documents)} documents from .txt files.")

+    def is_query_malicious(self, query):
+        # Use the pre-trained model to check if the query has malicious intent
+        result = self.malicious_detector(query)[0]
+        label = result['label']
+        score = result['score']

+        # Consider the query malicious if the sentiment is negative with high confidence
+        if label == "NEGATIVE" and score > 0.8:
+            print(f"Warning: Malicious query detected - Confidence: {score:.4f}")
+            return True
+        return False

+    def search_documents(self, query):
+        if self.is_query_malicious(query):
+            return [{"document": "ANOMALY: Query blocked due to detected malicious intent.", "similarity": 0.0}]

+        # Use fuzzy matching for normal queries
+        matches = get_close_matches(query, self.documents, n=5, cutoff=0.3)

+        if not matches:
+            return [{"document": "No matching documents found.", "similarity": 0.0}]

+        return [{"document": match[:100] + "..."} for match in matches]

+# Test the system with normal and malicious queries
+def test_document_search():
+    searcher = DocumentSearcher()

+    # Load the IMDB movie reviews
+    searcher.load_imdb_data()

+    # Load additional .txt files
+    searcher.load_txt_files()

+    # Perform a normal query
+    normal_query = "This movie had great acting and a compelling storyline."
+    normal_results = searcher.search_documents(normal_query)
+    print("Normal Query Results:")
+    for result in normal_results:
+        print(f"Document: {result['document']}")

+    # Perform a query injection attack
+    malicious_query = "DROP TABLE reviews; SELECT * FROM confidential_data;"
+    attack_results = searcher.search_documents(malicious_query)
+    print("\nMalicious Query Results:")
+    for result in attack_results:
+        print(f"Document: {result['document']}")

 if __name__ == "__main__":
+    test_document_search()
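
For reference, a minimal usage sketch of the new DocumentSearcher outside the built-in test. It assumes the repository root is on PYTHONPATH so rag_sec.rag_chagu_demo is importable, and it substitutes a hypothetical ~/my-txt-notes directory for the default ~/data-sets/txt-files/ corpus:

    from pathlib import Path
    from rag_sec.rag_chagu_demo import DocumentSearcher

    searcher = DocumentSearcher()
    # load_txt_files expects a Path; pass a custom directory, or call it with no
    # argument to fall back to the default location ("my-txt-notes" is hypothetical)
    searcher.load_txt_files(Path.home() / "my-txt-notes")

    # Each result is a dict with a "document" key (truncated matched text or a status message)
    for hit in searcher.search_documents("great acting and a compelling storyline"):
        print(hit["document"])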