talexm committed
Commit 73321dd
Parent: 92c34be

update RAG query improvements
.gitignore CHANGED
@@ -1,4 +1,3 @@
 rag_sec/__pycache*
-
-rag_sec/__pycache__/rag_chagu_demo.*
-
+falocon_api/embeddings.db
+rag_sec/__pycache__/rag_chagu_demo*
falocon_api/embeddingGenerator.py CHANGED
@@ -92,7 +92,7 @@ if __name__ == "__main__":
     embedding_generator.ingest_files(os.path.expanduser("~/data-sets/aclImdb/train/"))
 
     # Perform a search query
-    query = "What can be used for document search?"
+    query = "What can be used for document search?"  # alt: "DROP TABLE reviews; SELECT * FROM confidential_data;"
     results = embedding_generator.find_most_similar(query, top_k=3)
 
     print("Search Results:")
falocon_api/embededGeneratorRAG.py CHANGED
@@ -109,7 +109,7 @@ if __name__ == "__main__":
     embedding_generator.ingest_files(os.path.expanduser("~/data-sets/aclImdb/train/"))
 
    # Perform a search query with RAG response generation
-    query = "find user comments tt0118866"
+    query = "DROP TABLE reviews; SELECT * FROM confidential_data;"  # was: "find user comments tt0118866"
     response = embedding_generator.find_most_similar_and_generate(query)
 
     print("Generated Response:")
rag_sec/rag_chagu_demo.py CHANGED
@@ -1,100 +1,104 @@
-import os
-from pathlib import Path
-from difflib import get_close_matches
 from transformers import pipeline
+from difflib import get_close_matches
+from pathlib import Path
+import os
 
-class DocumentSearcher:
+
+class BadQueryDetector:
     def __init__(self):
-        self.documents = []
-        # Load a pre-trained model for malicious intent detection
-        self.malicious_detector = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
+        self.detector = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
 
-    def load_imdb_data(self):
-        home_dir = Path(os.getenv("HOME", "/"))
-        data_dir = home_dir / "data-sets/aclImdb/train"
-        pos_dir = data_dir / "pos"
-        neg_dir = data_dir / "neg"
-
-        print(f"Looking for positive reviews in: {pos_dir}")
-        print(f"Looking for negative reviews in: {neg_dir}")
-
-        if not pos_dir.exists() or not any(pos_dir.iterdir()):
-            print("No positive reviews found.")
-        if not neg_dir.exists() or not any(neg_dir.iterdir()):
-            print("No negative reviews found.")
-
-        for filename in pos_dir.iterdir():
-            with open(filename, "r", encoding="utf-8") as file:
-                self.documents.append(file.read())
-
-        for filename in neg_dir.iterdir():
-            with open(filename, "r", encoding="utf-8") as file:
-                self.documents.append(file.read())
-
-        print(f"Loaded {len(self.documents)} movie reviews from IMDB dataset.")
-
-    def load_txt_files(self, txt_dir=None):
-        if txt_dir is None:
-            home_dir = Path(os.getenv("HOME", "/"))
-            txt_dir = home_dir / "data-sets/txt-files/"
-
-        if not txt_dir.exists():
-            print("No .txt files directory found.")
-            return
-
-        for filename in txt_dir.glob("*.txt"):
-            with open(filename, "r", encoding="utf-8") as file:
-                self.documents.append(file.read())
-
-        print(f"Loaded additional {len(self.documents)} documents from .txt files.")
-
-    def is_query_malicious(self, query):
-        # Use the pre-trained model to check if the query has malicious intent
-        result = self.malicious_detector(query)[0]
-        label = result['label']
-        score = result['score']
-
-        # Consider the query malicious if the sentiment is negative with high confidence
-        if label == "NEGATIVE" and score > 0.8:
-            print(f"Warning: Malicious query detected - Confidence: {score:.4f}")
-            return True
-        return False
-
-    def search_documents(self, query):
-        if self.is_query_malicious(query):
-            return [{"document": "ANOMALY: Query blocked due to detected malicious intent.", "similarity": 0.0}]
-
-        # Use fuzzy matching for normal queries
-        matches = get_close_matches(query, self.documents, n=5, cutoff=0.3)
-
-        if not matches:
-            return [{"document": "No matching documents found.", "similarity": 0.0}]
-
-        return [{"document": match[:100] + "..."} for match in matches]
-
-# Test the system with normal and malicious queries
-def test_document_search():
-    searcher = DocumentSearcher()
-
-    # Load the IMDB movie reviews
-    searcher.load_imdb_data()
-
-    # Load additional .txt files
-    searcher.load_txt_files()
-
-    # Perform a normal query
-    normal_query = "This movie had great acting and a compelling storyline."
-    normal_results = searcher.search_documents(normal_query)
-    print("Normal Query Results:")
-    for result in normal_results:
-        print(f"Document: {result['document']}")
-
-    # Perform a query injection attack
-    malicious_query = "DROP TABLE reviews; SELECT * FROM confidential_data;"
-    attack_results = searcher.search_documents(malicious_query)
-    print("\nMalicious Query Results:")
-    for result in attack_results:
-        print(f"Document: {result['document']}")
+    def is_bad_query(self, query):
+        result = self.detector(query)[0]
+        label = result["label"]
+        score = result["score"]
+        # Mark queries as malicious or bad if negative sentiment with high confidence
+        if label == "NEGATIVE" and score > 0.8:
+            print(f"Detected malicious query with high confidence ({score:.4f}): {query}")
+            return True
+        return False
+
+
+class QueryTransformer:
+    def transform_query(self, query):
+        # Simple transformation example: rephrasing and clarifying
+        # In practice, this could involve more sophisticated models like T5
+        if "DROP TABLE" in query or "SELECT *" in query:
+            return "Your query appears to contain SQL injection elements. Please rephrase."
+        # Add more sophisticated handling here
+        return query
+
+
+class DocumentRetriever:
+    def __init__(self):
+        self.documents = []
+
+    def load_documents(self, source_dir):
+        data_dir = Path(source_dir)
+        if not data_dir.exists():
+            print(f"Source directory not found: {source_dir}")
+            return
+
+        for file in data_dir.glob("*.txt"):
+            with open(file, "r", encoding="utf-8") as f:
+                self.documents.append(f.read())
+
+        print(f"Loaded {len(self.documents)} documents.")
+
+    def retrieve(self, query):
+        matches = get_close_matches(query, self.documents, n=5, cutoff=0.3)
+        return matches if matches else ["No matching documents found."]
+
+
+class SemanticResponseGenerator:
+    def __init__(self):
+        self.generator = pipeline("text-generation", model="gpt2")
+
+    def generate_response(self, retrieved_docs):
+        # Generate a semantic response using retrieved documents
+        combined_docs = " ".join(retrieved_docs[:2])  # Use top 2 matches for response
+        response = self.generator(f"Based on the following information: {combined_docs}", max_length=100)
+        return response[0]["generated_text"]
+
+
+class DocumentSearchSystem:
+    def __init__(self):
+        self.detector = BadQueryDetector()
+        self.transformer = QueryTransformer()
+        self.retriever = DocumentRetriever()
+        self.response_generator = SemanticResponseGenerator()
+
+    def process_query(self, query):
+        if self.detector.is_bad_query(query):
+            return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}
+
+        transformed_query = self.transformer.transform_query(query)
+        retrieved_docs = self.retriever.retrieve(transformed_query)
+
+        if "No matching documents found." in retrieved_docs:
+            return {"status": "no_results", "message": "No relevant documents found for your query."}
+
+        response = self.response_generator.generate_response(retrieved_docs)
+        return {"status": "success", "response": response}
+
+
+# Test the enhanced system
+def test_system():
+    system = DocumentSearchSystem()
+    system.retriever.load_documents("/path/to/documents")
+
+    # Test with a normal query
+    normal_query = "Tell me about great acting performances."
+    normal_result = system.process_query(normal_query)
+    print("\nNormal Query Result:")
+    print(normal_result)
+
+    # Test with a malicious query
+    malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
+    malicious_result = system.process_query(malicious_query)
+    print("\nMalicious Query Result:")
+    print(malicious_result)
+
 
 if __name__ == "__main__":
-    test_document_search()
+    test_system()
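A note on the detection approach: BadQueryDetector reuses an SST-2 sentiment model, so "malicious" here effectively means "strongly negative sentiment". A minimal sketch of what the pipeline call returns, assuming the transformers API used in the diff (exact scores will vary):

from transformers import pipeline

# Same pipeline as BadQueryDetector. SST-2 is a movie-review sentiment
# model, not a security classifier, so an injection string is blocked
# only if it happens to score NEGATIVE with confidence above 0.8.
detector = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
result = detector("DROP TABLE users; SELECT * FROM sensitive_data;")[0]
print(result["label"], round(result["score"], 4))  # e.g. NEGATIVE 0.9x (varies by model version)

Note also that DocumentRetriever.retrieve fuzzy-matches the query against entire documents with difflib.get_close_matches, which compares whole strings, so long documents will rarely clear the 0.3 cutoff for a short query.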