import math
import os
import re
import string
from collections import Counter

import numpy as np
import pandas as pd
import torch
from pyvi.ViTokenizer import tokenize


class BM25:
    """Okapi-style BM25 ranking over a pre-tokenized corpus.

    Parameters
    ----------
    k1 : float
        Term-frequency saturation constant used in the denominator of `_score`.
    b : float
        Strength of document-length normalization: 0 disables it, 1 applies
        ``doc_len / avg_doc_len`` fully.

    Note: ``Retrieval`` can load a pickled index with ``torch.load``
    (presumably an instance of this class). The fitted attribute names
    (``tf_``, ``df_``, ``idf_``, ``doc_len_``, ``corpus_``, ``corpus_size_``,
    ``avg_doc_len_``) are therefore kept stable so old pickles keep working.
    """

    def __init__(self, k1=1.5, b=0.75):
        self.b = b
        self.k1 = k1

    def fit(self, corpus):
        """Compute the corpus statistics required for BM25 scoring.

        Parameters
        ----------
        corpus : list[list[str]]
            Each element is one document, given as its list of terms.

        Returns
        -------
        self
        """
        tf = []         # per document: {term: occurrences in that document}
        df = Counter()  # term -> number of documents containing it
        doc_len = []
        for document in corpus:
            doc_len.append(len(document))
            frequencies = dict(Counter(document))
            tf.append(frequencies)
            df.update(frequencies.keys())  # count each term once per document

        corpus_size = len(doc_len)
        # Smoothed IDF. Since freq <= corpus_size, the ratio is > 0, so the
        # "1 +" keeps every IDF strictly positive, even for very common terms.
        idf = {
            term: math.log(1 + (corpus_size - freq + 0.5) / (freq + 0.5))
            for term, freq in df.items()
        }

        self.tf_ = tf
        self.df_ = dict(df)
        self.idf_ = idf
        self.doc_len_ = doc_len
        self.corpus_ = corpus
        self.corpus_size_ = corpus_size
        # An empty corpus previously raised ZeroDivisionError here.
        self.avg_doc_len_ = sum(doc_len) / corpus_size if corpus_size else 0.0
        return self

    def search(self, query):
        """Score every fitted document against ``query``.

        Parameters
        ----------
        query : list[str]
            Query terms, tokenized the same way as the fitted corpus.

        Returns
        -------
        list[float]
            One BM25 score per document, in corpus order.
        """
        return [self._score(query, index) for index in range(self.corpus_size_)]

    def _score(self, query, index):
        """Return the BM25 score of document ``index`` for the query terms.

        Repeated query terms contribute once per occurrence.
        """
        frequencies = self.tf_[index]
        if not frequencies:
            # Empty document: nothing can match. This also avoids dividing by
            # avg_doc_len_, which is 0 when every document is empty.
            return 0.0
        # Length-normalization part of the denominator; constant per document.
        length_norm = self.k1 * (
            1 - self.b + self.b * self.doc_len_[index] / self.avg_doc_len_
        )
        score = 0.0
        for term in query:
            freq = frequencies.get(term)
            if freq is None:
                continue
            score += self.idf_[term] * freq * (self.k1 + 1) / (freq + length_norm)
        return score


class Retrieval:
    """Return the passages of a document that best match a query, ranked by BM25.

    Passages and queries go through the same preprocessing (`_preprocess`):
    strip ``<...>`` tags, collapse whitespace, word-segment with pyvi, turn
    ASCII punctuation (except ``_``) into spaces, lowercase, drop stopwords.
    """

    # Every ASCII punctuation character except '_' is mapped to a space.
    # '_' is kept, presumably because pyvi joins the syllables of a
    # multi-syllable word with it (TODO confirm against pyvi output).
    _PUNCTUATION = string.punctuation.replace('_', '')
    _PUNCT_TO_SPACE = str.maketrans(_PUNCTUATION, ' ' * len(_PUNCTUATION))

    def __init__(
        self,
        k=8,
        model='retrieval/bm25.pt',
        contexts='retrieval/context.pt',
        stop_words='retrieval/stopwords.csv',
        max_len=400,
        docs=None,
    ) -> None:
        """
        Parameters
        ----------
        k : int
            Number of passages returned by `get_context`.
        model : str
            Path of a pickled BM25 index, loaded when ``docs`` is falsy.
        contexts : str
            Path of the pickled passage list matching ``model``.
        stop_words : str
            Tab-separated UTF-8 file with a ``stopwords`` column.
        max_len : int
            Word budget per passage used by `split`. A passage can exceed it
            when one sentence (plus the one-sentence overlap) is longer.
        docs : str, optional
            Raw document text. When truthy, it is split into passages and a
            fresh BM25 index is fit on them instead of loading from disk.
        """
        self.k = k
        self.max_len = max_len
        data = pd.read_csv(stop_words, sep="\t", encoding='utf-8')
        # ``word in series`` on a pandas Series tests its *index* (0..n-1),
        # not its values, so the old Series-based filter never removed
        # anything. A set of the values is correct and O(1) per lookup.
        self.list_stopwords = set(data['stopwords'].dropna())
        if docs:
            self.tuning(docs)
        else:
            self.bm25 = self._load_pickled(model)
            self.contexts = self._load_pickled(contexts)

    @staticmethod
    def _load_pickled(path):
        """Load an arbitrary pickled Python object saved with ``torch.save``.

        SECURITY: unpickling can execute arbitrary code; only load trusted files.
        """
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle non-tensor objects such as a BM25 instance.
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older torch releases do not accept the `weights_only` keyword.
            return torch.load(path)

    @staticmethod
    def _clean_text(text):
        """Remove ``<...>`` tags, strip, and collapse each whitespace run to one character."""
        text = re.sub(r'<.*?>', '', text).strip()
        return re.sub(r'(\s)+', r'\1', text)

    def _preprocess(self, text):
        """Run the shared cleaning pipeline on ``text`` and return its tokens."""
        text = self._clean_text(text)
        text = tokenize(text)  # pyvi word segmentation
        text = text.translate(self._PUNCT_TO_SPACE).lower()
        return [word for word in text.split() if word not in self.list_stopwords]

    def get_context(self, query='Chảy máu chân răng là bệnh gì?'):
        """Return the top-``k`` passages for ``query``, best match first.

        Each result is a dict with ``'score_bm'`` (BM25 score), ``'index'``
        (position in ``self.contexts``) and ``'context'`` (passage text).
        Also prints the selected scores.
        """
        scores = self.bm25.search(self._preprocess(query))
        # Clamp: when the index is loaded from disk, self.k is never reduced to
        # the number of passages, and an oversized k used to raise IndexError.
        top_n = max(0, min(self.k, len(scores)))
        ranked = np.argsort(scores)[::-1][:top_n]
        results = [
            {'score_bm': scores[index], 'index': index, 'context': self.contexts[index]}
            for index in ranked
        ]
        print("BM25:", [result['score_bm'] for result in results])
        return results

    def split(self, document):
        """Split ``document`` into passages and store them in ``self.contexts``.

        Sentences are delimited by ``'. '`` and re-terminated with ``'. '``.
        Sentences are packed into a passage until adding the next one would
        exceed ``self.max_len`` words. Each new passage starts with the last
        sentence of the previous one, so consecutive passages overlap by one
        sentence. ``self.k`` is reduced to the number of passages if larger.
        """
        document = re.sub(' +', ' ', document.replace('\n', ' '))
        context_list = []
        context = ""
        length = 0
        prev_sentence = ""
        prev_len = 0
        for sentence in document.split('. '):
            sentence += '. '
            n_words = len(sentence.split())
            # `context` is only empty before the first sentence; flushing it
            # then used to emit an empty passage.
            if length + n_words > self.max_len and context:
                context_list.append(context)
                context = prev_sentence
                length = prev_len
            length += n_words
            context += sentence
            prev_sentence, prev_len = sentence, n_words
        context_list.append(context)
        self.contexts = context_list
        if len(context_list) < self.k:
            self.k = len(context_list)

    def tuning(self, document):
        """Split ``document`` into passages and fit a fresh BM25 index on them."""
        self.split(document)
        texts = [self._preprocess(content) for content in self.contexts]
        print('There is', len(texts), 'contexts')
        self.bm25 = BM25()
        self.bm25.fit(texts)