File size: 6,336 Bytes
74b1bac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from __future__ import annotations

from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
import bm25s
import re
import string
from tqdm import tqdm
from pyvi.ViTokenizer import tokenize

def clean_text(text: str) -> str:
    """Strip HTML-like tags, drop invalid UTF-8 sequences, and collapse whitespace."""
    no_tags = re.sub(r'<.*?>', '', text).strip()
    utf8_only = no_tags.encode('utf-8', 'ignore').decode('utf-8')
    return re.sub(r'\s+', ' ', utf8_only).strip()

def normalize_text(text: str) -> str:
    """Lowercase *text*, map punctuation to spaces, and collapse whitespace.

    Underscore is deliberately kept — presumably because the upstream word
    segmenter joins multi-syllable tokens with '_' (TODO confirm against
    pyvi output).

    Uses a single C-level ``str.translate`` pass instead of one
    ``str.replace`` call per punctuation character; output is identical.
    """
    punctuation = string.punctuation.replace('_', '')
    # One translation table, one pass over the string.
    table = str.maketrans({ch: ' ' for ch in punctuation})
    text = text.translate(table).lower().strip()
    return re.sub(r'\s+', ' ', text).strip()

def process_text(text: str) -> str:
    """Full preprocessing pipeline: clean, Vietnamese word-segment, normalize."""
    for stage in (clean_text, tokenize, normalize_text):
        text = stage(text)
    return text

def default_preprocessing_func(text: str) -> List[str]:
    """Preprocess a query string or a corpus, then tokenize with bm25s.

    Args:
        text: A single string (query) or an iterable of strings (corpus).
            The original only accepted ``str`` or ``tuple``; any other
            input (e.g. a ``list``, as the retriever docstrings suggest)
            left ``fin_text`` unbound and raised ``NameError``. Any
            iterable of strings is now accepted.

    Returns:
        Tokens produced by ``bm25s.tokenize`` with Vietnamese stopwords.
    """
    if isinstance(text, str):
        processed = process_text(text)
    else:
        # Corpus path: preprocess each document, with a progress bar.
        processed = [process_text(doc) for doc in tqdm(text)]
    return bm25s.tokenize(
        texts=processed, stopwords="vi", return_ids=False, show_progress=False
    )


class BM25SRetriever(BaseRetriever):
    """Retriever over a fixed list of documents backed by a BM25S index.

    Only the sync ``_get_relevant_documents`` is implemented. As usual with
    Runnables, a default async implementation is provided that delegates to
    the sync implementation running on another thread.
    """

    vectorizer: Any
    """ BM25S vectorizer."""
    docs: List[Document] = Field(repr=False)
    """List of documents to retrieve from."""
    k: int = 4
    """Number of top results to return"""
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
    """ Preprocessing function to use on the text before BM25 vectorization."""
    save_directory: Optional[str] = None
    """ Directory for saving/loading the BM25S index."""
    activate_numba: bool = False
    """Accelerate backend"""

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        save_directory: Optional[str] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25SRetriever:
        """
        Create a BM25SRetriever from a list of texts.

        Args:
            texts: Texts to vectorize.
            metadatas: Optional metadata dicts, one per text.
            bm25_params: Parameters to pass to the ``bm25s.BM25`` constructor.
            save_directory: If set and an index exists there, load it instead
                of re-indexing; otherwise the freshly built index is saved
                there. (The old default expression ``save_directory=save_directory``
                silently captured the class-body binding ``None`` — now explicit.)
            preprocess_func: Function applied to the corpus before indexing.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25SRetriever instance.
        """
        try:
            from bm25s import BM25
        except ImportError:
            raise ImportError(
                "Could not import bm25s, please install with `pip install "
                "bm25s`."
            )
        # Materialize once: `texts` may be a one-shot iterator, and it is
        # consumed both by preprocessing and by the Document zip below.
        # A tuple also matches what `default_preprocessing_func` treats as
        # a corpus.
        texts = tuple(texts)
        bm25_params = bm25_params or {}

        vectorizer = None
        if save_directory and Path(save_directory).exists():
            try:
                vectorizer = BM25.load(save_directory)
            except Exception as e:
                # Best-effort load; fall through to rebuilding from scratch.
                print(f"Failed to load BM25 index from {save_directory}: {e}")
                print("Proceeding with indexing from scratch.")
        if vectorizer is None:
            # Build (and optionally persist) the index from scratch.
            vectorizer = BM25(**bm25_params)
            vectorizer.index(preprocess_func(texts))
            if save_directory:
                vectorizer.save(save_directory)

        metadatas = metadatas or ({} for _ in texts)
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        return cls(
            vectorizer=vectorizer,
            docs=docs,
            preprocess_func=preprocess_func,
            save_directory=save_directory,
            **kwargs,
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25SRetriever:
        """
        Create a BM25SRetriever from a list of Documents.

        Args:
            documents: Documents to vectorize.
            bm25_params: Parameters to pass to the BM25 vectorizer.
            preprocess_func: Function applied to the corpus before indexing.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25SRetriever instance.
        """
        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
        return cls.from_texts(
            texts=texts,
            bm25_params=bm25_params,
            metadatas=metadatas,
            preprocess_func=preprocess_func,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the top-``k`` documents for ``query``."""
        processed_query = self.preprocess_func(query)
        if self.activate_numba:
            # NOTE(review): the numba scorer is (re)activated on every call;
            # confirm bm25s makes this idempotent/cheap.
            self.vectorizer.activate_numba_scorer()
            results = self.vectorizer.retrieve(
                processed_query, k=self.k, backend_selection="numba"
            )
            # Without a corpus argument, retrieve() yields indices into docs.
            return [self.docs[i] for i in results.documents[0]]
        # With the corpus passed in, retrieve() returns the Documents
        # themselves as a (1, k) array, alongside their scores.
        retrieved, _scores = self.vectorizer.retrieve(
            processed_query, self.docs, k=self.k
        )
        return [retrieved[0, i] for i in range(retrieved.shape[1])]