chagu-demo / rag_sec /document_retriver.py
talexm
update
6dd2090
raw
history blame
711 Bytes
from sklearn.datasets import fetch_20newsgroups
class DocumentRetriever:
    """Hold a list of text documents and answer substring queries on them.

    The documents live in ``self.documents`` (a plain list of strings). It
    starts empty and is filled by :meth:`load_documents`.
    """

    def __init__(self):
        """Create a retriever with no documents loaded."""
        # Filled by load_documents(); retrieve() treats an empty list as
        # "not initialized".
        self.documents = []

    def load_documents(self, subset_size: int = 500, *, subset: str = "all") -> None:
        """Load the first ``subset_size`` documents of the 20 Newsgroups dataset.

        Args:
            subset_size: Upper bound on the number of documents kept. It is
                used as a slice end (``data[:subset_size]``), so fewer
                documents are kept if the dataset is smaller.
            subset: Forwarded to ``fetch_20newsgroups(subset=...)``. Defaults
                to ``"all"``, which matches the previous hard-coded value.

        Side effects:
            Replaces ``self.documents`` and prints how many documents were
            loaded.
        """
        newsgroups_data = fetch_20newsgroups(subset=subset)
        # Keep only a prefix of the dataset to bound memory and search time.
        self.documents = newsgroups_data.data[:subset_size]
        print(f"Loaded {len(self.documents)} documents.")

    def retrieve(self, query: str) -> list:
        """Return every loaded document that contains ``query``.

        Matching is a case-insensitive substring test (both sides are
        lower-cased). The order of ``self.documents`` is preserved.

        Args:
            query: Text to look for inside each document.

        Returns:
            A list of matching documents (possibly empty). If no documents
            are loaded, the list contains a single placeholder message
            instead of search results.
        """
        if not self.documents:
            # Callers receive a message rather than an exception when the
            # retriever has not been populated yet.
            return ["Document retrieval is not initialized."]

        # Lower-case the query once instead of once per document.
        needle = query.lower()
        return [doc for doc in self.documents if needle in doc.lower()]