talexm committed on
Commit
ed26242
1 Parent(s): 1d44212

news-data retrieval

Browse files
Files changed (1) hide show
  1. rag_sec/document_retriver.py +12 -36
rag_sec/document_retriver.py CHANGED
@@ -1,47 +1,23 @@
1
  import faiss
2
  from sklearn.feature_extraction.text import TfidfVectorizer
3
  import numpy as np
 
4
 
5
  class DocumentRetriever:
6
  def __init__(self):
7
  self.documents = []
8
- self.vectorizer = TfidfVectorizer()
9
- self.index = None
10
 
11
- def load_documents(self, source_dir):
12
- from pathlib import Path
 
 
 
 
13
 
14
- data_dir = Path(source_dir)
15
- if not data_dir.exists():
16
- print(f"Source directory not found: {source_dir}")
17
- return
18
-
19
- for file in data_dir.glob("*.txt"):
20
- with open(file, "r", encoding="utf-8") as f:
21
- self.documents.append(f.read())
22
-
23
- print(f"Loaded {len(self.documents)} documents.")
24
-
25
- # Create the FAISS index
26
- self._build_index()
27
-
28
- def _build_index(self):
29
- # Generate TF-IDF vectors for documents
30
- doc_vectors = self.vectorizer.fit_transform(self.documents).toarray()
31
-
32
- # Create FAISS index
33
- self.index = faiss.IndexFlatL2(doc_vectors.shape[1])
34
- self.index.add(doc_vectors.astype(np.float32))
35
-
36
- def retrieve(self, query, top_k=5):
37
- if not self.index:
38
  return ["Document retrieval is not initialized."]
 
 
39
 
40
- # Vectorize the query
41
- query_vector = self.vectorizer.transform([query]).toarray().astype(np.float32)
42
-
43
- # Perform FAISS search
44
- distances, indices = self.index.search(query_vector, top_k)
45
-
46
- # Return matching documents
47
- return [self.documents[i] for i in indices[0] if i < len(self.documents)]
 
1
  import faiss
2
  from sklearn.feature_extraction.text import TfidfVectorizer
3
  import numpy as np
4
+ from sklearn.datasets import fetch_20newsgroups
5
 
6
  class DocumentRetriever:
7
  def __init__(self):
8
  self.documents = []
 
 
9
 
10
+ def load_documents(self):
11
+ """Load 20 Newsgroups dataset."""
12
+ newsgroups_data = fetch_20newsgroups(subset='all')
13
+ self.documents = newsgroups_data.data
14
+ if not self.documents:
15
+ print("No documents loaded!")
16
 
17
+ def retrieve(self, query):
18
+ """Retrieve documents related to the query."""
19
+ if not self.documents:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  return ["Document retrieval is not initialized."]
21
+ # Simple keyword match (can replace with advanced semantic similarity later)
22
+ return [doc for doc in self.documents if query.lower() in doc.lower()]
23