File size: 6,336 Bytes
74b1bac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
from __future__ import annotations
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
import bm25s
import re
import string
from tqdm import tqdm
from pyvi.ViTokenizer import tokenize
def clean_text(text: str) -> str:
    """Strip HTML-like tags, drop non-UTF-8-encodable characters, and collapse whitespace."""
    without_tags = re.sub('<.*?>', '', text).strip()
    # Round-trip through UTF-8 with errors ignored to remove unencodable characters.
    utf8_only = without_tags.encode('utf-8', 'ignore').decode('utf-8')
    return re.sub(r'\s+', ' ', utf8_only).strip()
def normalize_text(text: str) -> str:
    """Lowercase *text*, replace punctuation (except ``_``) with spaces, and collapse whitespace.

    ``_`` is preserved because pyvi's word segmentation joins syllables of a
    Vietnamese compound word with underscores.
    """
    # Single C-level pass with str.translate instead of one .replace() call
    # per punctuation character; behavior is identical.
    punctuation = string.punctuation.replace('_', '')
    table = str.maketrans({ch: ' ' for ch in punctuation})
    text = text.translate(table).lower().strip()
    return re.sub(r'\s+', ' ', text).strip()
def process_text(text: str) -> str:
    """Full preprocessing pipeline: clean, word-segment with pyvi, then normalize."""
    return normalize_text(tokenize(clean_text(text)))
def default_preprocessing_func(text: str) -> List[str]:
    """Preprocess and BM25S-tokenize either a single string or a corpus.

    Args:
        text: A single query string, or an iterable of document strings
            (tuple, list, generator, ...).

    Returns:
        Token output from ``bm25s.tokenize`` with Vietnamese stopwords removed.

    Bug fix: the original only handled ``str`` and ``tuple``; any other
    iterable (e.g. a list) left ``fin_text`` unbound and raised
    ``UnboundLocalError``. Now any non-str iterable is treated as a corpus.
    """
    if isinstance(text, str):
        # Single query string.
        fin_text = text = process_text(text)
    else:
        # Corpus case: preprocess every document, with a progress bar.
        fin_text = [process_text(doc) for doc in tqdm(text)]
    return bm25s.tokenize(
        texts=fin_text, stopwords="vi", return_ids=False, show_progress=False
    )
class BM25SRetriever(BaseRetriever):
    """Retriever returning the top-k documents scored by a BM25S index.

    Only the sync ``_get_relevant_documents`` is implemented. If the
    retriever were to involve file access or network access, it could benefit
    from a native async implementation of ``_aget_relevant_documents``; as
    usual with Runnables, the default async implementation delegates to the
    sync one on another thread.
    """

    vectorizer: Any
    # BM25S vectorizer (a ``bm25s.BM25`` instance).
    docs: List[Document] = Field(repr=False)
    # List of documents to retrieve from.
    k: int = 4
    # Number of top results to return.
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
    # Preprocessing function applied to text before BM25 vectorization.
    save_directory: Optional[str] = None
    # Directory for saving/loading the BM25S index.
    activate_numba: bool = False
    # Use bm25s's numba-accelerated scoring backend.

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        # NOTE: the original default was `save_directory=save_directory`,
        # which silently picked up the class-body binding (None) at def time.
        save_directory: Optional[str] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25SRetriever:
        """
        Create a BM25SRetriever from a list of texts.

        If ``save_directory`` exists, a previously saved index is loaded from
        it; otherwise (or if loading fails) the corpus is indexed from scratch
        and, when ``save_directory`` is set, persisted there.

        Args:
            texts: Texts to vectorize.
            metadatas: Optional metadata dicts, one per text.
            bm25_params: Parameters passed to the ``bm25s.BM25`` constructor.
            save_directory: Directory for loading/saving the BM25S index.
            preprocess_func: Function applied to texts before vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25SRetriever instance.
        """
        try:
            from bm25s import BM25
        except ImportError:
            raise ImportError(
                "Could not import bm25s, please install with `pip install "
                "bm25s`."
            )
        bm25_params = bm25_params or {}
        # Materialize once: a generator would be exhausted by preprocess_func
        # and then yield zero documents in the zip() below. A tuple also
        # matches the corpus branch of default_preprocessing_func.
        texts = tuple(texts)

        def _build_index() -> Any:
            # Index the corpus from scratch and optionally persist it.
            vec = BM25(**bm25_params)
            vec.index(preprocess_func(texts))
            if save_directory:
                vec.save(save_directory)
            return vec

        vectorizer = None
        if save_directory and Path(save_directory).exists():
            try:
                vectorizer = BM25.load(save_directory)
            except Exception as e:
                # Best-effort load: fall back to reindexing on any failure.
                print(f"Failed to load BM25 index from {save_directory}: {e}")
                print("Proceeding with indexing from scratch.")
        if vectorizer is None:
            vectorizer = _build_index()
        # Genexpr yields a fresh empty dict per text when no metadata given.
        metadatas = metadatas or ({} for _ in texts)
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        return cls(
            vectorizer=vectorizer,
            docs=docs,
            preprocess_func=preprocess_func,
            save_directory=save_directory,
            **kwargs,
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25SRetriever:
        """
        Create a BM25SRetriever from a list of Documents.

        Args:
            documents: Documents to vectorize.
            bm25_params: Parameters passed to the ``bm25s.BM25`` constructor.
            preprocess_func: Function applied to texts before vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25SRetriever instance.
        """
        # Built as two explicit tuples instead of zip(*...) so an empty
        # iterable of documents no longer raises ValueError on unpacking.
        pairs = [(d.page_content, d.metadata) for d in documents]
        texts = tuple(t for t, _ in pairs)
        metadatas = tuple(m for _, m in pairs)
        return cls.from_texts(
            texts=texts,
            bm25_params=bm25_params,
            metadatas=metadatas,
            preprocess_func=preprocess_func,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the top ``self.k`` documents for ``query``."""
        processed_query = self.preprocess_func(query)
        if self.activate_numba:
            # NOTE(review): the numba scorer is re-activated on every query;
            # presumably idempotent in bm25s — confirm against its docs.
            self.vectorizer.activate_numba_scorer()
            results = self.vectorizer.retrieve(
                processed_query, k=self.k, backend_selection="numba"
            )
            # Without a corpus argument, retrieve() returns indices into docs.
            return [self.docs[i] for i in results.documents[0]]
        # With self.docs passed as the corpus, retrieve() returns the
        # documents themselves plus their scores.
        results, _scores = self.vectorizer.retrieve(processed_query, self.docs, k=self.k)
        return [results[0, i] for i in range(results.shape[1])]
|