# Graduation / pipelines / BM25 / bm25sretriever.py
# Uploaded by DuyTa with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 74b1bac (verified), 6.34 kB.
from __future__ import annotations
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
import bm25s
import re
import string
from tqdm import tqdm
from pyvi.ViTokenizer import tokenize
def clean_text(text: str) -> str:
    """Remove inline ``<...>`` tags and normalise whitespace.

    Steps: delete every non-greedy ``<...>`` match (the pattern does not span
    newlines), drop characters that cannot be encoded as UTF-8 (e.g. lone
    surrogates), then collapse whitespace runs to a single space and trim.
    """
    untagged = re.sub('<.*?>', '', text).strip()
    # Round-tripping with errors='ignore' silently discards un-encodable chars.
    encodable = untagged.encode('utf-8', 'ignore').decode('utf-8')
    collapsed = re.sub(r'\s+', ' ', encodable)
    return collapsed.strip()
# Translation table mapping every ASCII punctuation character except '_' to a
# space; built once at import time instead of on every call.
_PUNCTUATION_TO_SPACE = str.maketrans(
    {ch: ' ' for ch in string.punctuation if ch != '_'}
)


def normalize_text(text: str) -> str:
    """Lowercase *text*, turn punctuation into spaces and collapse whitespace.

    Every character of ``string.punctuation`` except ``_`` is replaced by a
    space. The underscore is kept, presumably because pyvi's tokenizer (run
    just before this in ``process_text``) joins compound-word syllables with
    it. The result is lowercased, whitespace runs are collapsed to a single
    space, and the ends are trimmed.
    """
    # One C-level translate pass instead of one str.replace pass per
    # punctuation character.
    text = text.translate(_PUNCTUATION_TO_SPACE).lower()
    return re.sub(r'\s+', ' ', text).strip()
def process_text(text: str) -> str:
    """Run the full per-text preprocessing pipeline.

    Applies, in order: ``clean_text`` (strip ``<...>`` tags, drop
    un-encodable characters, collapse whitespace), pyvi's ``ViTokenizer``
    ``tokenize`` (Vietnamese word segmentation), and ``normalize_text``
    (punctuation other than ``_`` to spaces, lowercase, collapse whitespace).
    """
    for step in (clean_text, tokenize, normalize_text):
        text = step(text)
    return text
def default_preprocessing_func(text: str | Iterable[str]) -> List[List[str]]:
    """Preprocess a query or a whole corpus and tokenize it for bm25s.

    Args:
        text: Either a single string (a query) or an iterable of strings
            (a corpus, e.g. the tuple built by ``BM25SRetriever.from_texts``).
            Each string goes through ``process_text``.

    Returns:
        The output of ``bm25s.tokenize(..., return_ids=False)`` with
        Vietnamese stopwords removed: one token list per input text. A single
        query string yields a one-element batch.
    """
    # Check for str first: a str is itself iterable and must not be treated
    # as a corpus of characters. Any other iterable (list, tuple, generator)
    # is treated as a corpus.
    if isinstance(text, str):
        fin_text = process_text(text)
    else:
        # Corpus preprocessing can be slow, so show a progress bar.
        fin_text = [process_text(doc) for doc in tqdm(text)]
    return bm25s.tokenize(
        texts=fin_text, stopwords="vi", return_ids=False, show_progress=False
    )
class BM25SRetriever(BaseRetriever):
    """LangChain retriever backed by a ``bm25s`` BM25 index.

    Each query is passed through ``preprocess_func`` and scored by
    ``vectorizer``; the ``k`` best-scoring entries of ``docs`` are returned.
    Only the sync ``_get_relevant_documents`` is implemented; LangChain's
    default async path delegates to it on another thread.
    """

    vectorizer: Any
    """BM25S vectorizer (a ``bm25s.BM25`` instance when built via ``from_texts``)."""
    docs: List[Document] = Field(repr=False)
    """Documents to retrieve from, in the same order they were indexed."""
    k: int = 4
    """Number of top results to return."""
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
    """Preprocessing applied to the corpus at index time and to each query."""
    save_directory: Optional[str] = None
    """Directory the BM25S index is loaded from / saved to."""
    activate_numba: bool = False
    """If True, retrieve with bm25s' numba backend."""

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        save_directory: Optional[str] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25SRetriever:
        """Create a BM25SRetriever from a list of texts.

        If ``save_directory`` points to an existing path, the index saved
        there is loaded and reused. Otherwise, or if loading fails, ``texts``
        are preprocessed and indexed from scratch. When ``save_directory`` is
        set, the new index is then saved there.

        Args:
            texts: Texts to index. Materialised into a tuple before use.
            metadatas: Optional metadata dicts, one per text. Empty dicts are
                used when omitted or empty.
            bm25_params: Keyword arguments forwarded to ``bm25s.BM25``.
            save_directory: Directory to load the index from / save it to.
            preprocess_func: Applied to the whole corpus before indexing and
                to each query at retrieval time.
            **kwargs: Other retriever fields (e.g. ``k``, ``activate_numba``).

        Returns:
            A BM25SRetriever instance.

        Raises:
            ImportError: If ``bm25s`` cannot be imported.
            ValueError: If ``metadatas`` and ``texts`` differ in length.
        """
        try:
            from bm25s import BM25
        except ImportError as err:
            raise ImportError(
                "Could not import bm25s, please install with `pip install "
                "bm25s`."
            ) from err
        bm25_params = bm25_params or {}
        # ``texts`` is read twice (preprocessing and Document construction),
        # so a one-shot iterator must be materialised first. A tuple matches
        # what ``from_documents`` passes.
        texts = tuple(texts)

        vectorizer = None
        if save_directory and Path(save_directory).exists():
            # NOTE: a saved index is reused as-is. It is not checked against
            # ``texts`` or ``bm25_params``; remove the directory to rebuild.
            try:
                vectorizer = BM25.load(save_directory)
            except Exception as e:
                print(f"Failed to load BM25 index from {save_directory}: {e}")
                print("Proceeding with indexing from scratch.")
        if vectorizer is None:
            texts_processed = preprocess_func(texts)
            vectorizer = BM25(**bm25_params)
            vectorizer.index(texts_processed)
            if save_directory:
                vectorizer.save(save_directory)

        metadatas = list(metadatas) if metadatas else [{} for _ in texts]
        # zip() would silently truncate and misalign docs with index rows.
        if len(metadatas) != len(texts):
            raise ValueError(
                f"Got {len(metadatas)} metadatas for {len(texts)} texts; "
                "the lengths must match."
            )
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        return cls(
            vectorizer=vectorizer,
            docs=docs,
            preprocess_func=preprocess_func,
            save_directory=save_directory,
            **kwargs,
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25SRetriever:
        """Create a BM25SRetriever from a list of Documents.

        Args:
            documents: Documents whose ``page_content`` is indexed and whose
                ``metadata`` is kept on the returned documents.
            bm25_params: Keyword arguments forwarded to ``bm25s.BM25``.
            preprocess_func: Applied to the corpus and to each query.
            **kwargs: Forwarded to ``from_texts`` (e.g. ``save_directory``, ``k``).

        Returns:
            A BM25SRetriever instance.
        """
        # Materialise once; unlike ``zip(*...)`` this also handles an empty input.
        documents = list(documents)
        texts = tuple(d.page_content for d in documents)
        metadatas = tuple(d.metadata for d in documents)
        return cls.from_texts(
            texts=texts,
            bm25_params=bm25_params,
            metadatas=metadatas,
            preprocess_func=preprocess_func,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``k`` documents ranked by BM25 score for *query*."""
        # bm25s rejects a k larger than the corpus, so cap it at the number
        # of documents we can return.
        k = min(self.k, len(self.docs))
        if k <= 0:
            return []
        processed_query = self.preprocess_func(query)
        if self.activate_numba:
            # Re-invoked on every query; presumably cheap after the first
            # call — TODO confirm against bm25s.
            self.vectorizer.activate_numba_scorer()
            results = self.vectorizer.retrieve(
                processed_query, k=k, backend_selection="numba"
            )
        else:
            results = self.vectorizer.retrieve(processed_query, k=k)
        # Without a corpus argument, ``results.documents`` holds index
        # positions into the indexed corpus, one row per query. There is a
        # single query here, so take row 0.
        return [self.docs[i] for i in results.documents[0]]