from __future__ import annotations from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever import bm25s import re import string from tqdm import tqdm from pyvi.ViTokenizer import tokenize def clean_text(text: str) -> str: text = re.sub('<.*?>', '', text).strip() text = text.encode('utf-8', 'ignore').decode('utf-8') text = re.sub(r'\s+', ' ', text).strip() return text def normalize_text(text: str) -> str: list_punctuation = string.punctuation.replace('_', '') for punct in list_punctuation: text = text.replace(punct, ' ') text = text.lower().strip() text = re.sub(r'\s+', ' ', text).strip() return text def process_text(text: str) -> str: text = clean_text(text) text = tokenize(text) text = normalize_text(text) return text def default_preprocessing_func(text: str) -> List[str]: if type(text) == tuple : fin_text = [process_text(doc) for doc in tqdm(text)] elif type(text) == str : fin_text = process_text(text) token_corpus = bm25s.tokenize(texts=fin_text, stopwords = "vi", return_ids= False , show_progress=False) return token_corpus class BM25SRetriever(BaseRetriever): """A toy retriever that contains the top k documents that contain the user query. This retriever only implements the sync method _get_relevant_documents. If the retriever were to involve file access or network access, it could benefit from a native async implementation of `_aget_relevant_documents`. As usual, with Runnables, there's a default async implementation that's provided that delegates to the sync implementation running on another thread. """ vectorizer: Any """ BM25S vectorizer.""" docs: List[Document] = Field(repr=False) """List of documents to retrieve from.""" k: int = 4 """Number of top results to return""" preprocess_func: Callable[[str], List[str]] = default_preprocessing_func """ Preprocessing function to use on the text before BM25 vectorization.""" save_directory : Optional[str] = None """ Directory for saving BM25S index.""" activate_numba: bool = False """Accelerate backend""" class Config: arbitrary_types_allowed = True @classmethod def from_texts( cls, texts: Iterable[str], metadatas: Optional[Iterable[dict]] = None, bm25_params: Optional[Dict[str, Any]] = None, save_directory : Optional[str] = save_directory, preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, **kwargs: Any, ) -> BM25SRetriever: """ Create a BM25Retriever from a list of texts. Args: texts: A list of texts to vectorize. metadatas: A list of metadata dicts to associate with each text. bm25s_params: Parameters to pass to the BM25s vectorizer. preprocess_func: A function to preprocess each text before vectorization. **kwargs: Any other arguments to pass to the retriever. Returns: A BM25SRetriever instance. """ try: from bm25s import BM25 except ImportError: raise ImportError( "Could not import bm25s, please install with `pip install " "bm25s`." ) bm25_params = bm25_params or {} if save_directory and Path(save_directory).exists(): try: vectorizer = BM25.load(save_directory) except Exception as e: print(f"Failed to load BM25 index from {save_directory}: {e}") print("Proceeding with indexing from scratch.") texts_processed = preprocess_func(texts) vectorizer = BM25(**bm25_params) vectorizer.index(texts_processed) if save_directory: vectorizer.save(save_directory) else: texts_processed = preprocess_func(texts) vectorizer = BM25(**bm25_params) vectorizer.index(texts_processed) if save_directory: vectorizer.save(save_directory) metadatas = metadatas or ({} for _ in texts) docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)] return cls( vectorizer=vectorizer, docs=docs, preprocess_func=preprocess_func, save_directory=save_directory, **kwargs ) @classmethod def from_documents( cls, documents: Iterable[Document], *, bm25_params: Optional[Dict[str, Any]] = None, preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, **kwargs: Any, ) -> BM25SRetriever: """ Create a BM25Retriever from a list of Documents. Args: documents: A list of Documents to vectorize. bm25_params: Parameters to pass to the BM25 vectorizer. preprocess_func: A function to preprocess each text before vectorization. **kwargs: Any other arguments to pass to the retriever. Returns: A BM25Retriever instance. """ texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) return cls.from_texts( texts=texts, bm25_params=bm25_params, metadatas=metadatas, preprocess_func=preprocess_func, **kwargs, ) def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: processed_query = self.preprocess_func(query) if self.activate_numba : self.vectorizer.activate_numba_scorer() return_docs = self.vectorizer.retrieve(processed_query, k=self.k, backend_selection="numba") return [self.docs[i] for i in return_docs.documents[0]] else : return_docs, scores = self.vectorizer.retrieve(processed_query, self.docs, k = self.k) return [return_docs[0, i] for i in range(return_docs.shape[1])]