# VDT/retrieval/retrieval.py
import math
import re
import string

import numpy as np
import pandas as pd
import torch
from pyvi.ViTokenizer import tokenize


class BM25:
    """Okapi BM25 ranking over a pre-tokenized corpus."""

    def __init__(self, k1=1.5, b=0.75):
        self.b = b      # length-normalization strength (0 = none, 1 = full)
        self.k1 = k1    # term-frequency saturation parameter

def fit(self, corpus):
"""
Fit the various statistics that are required to calculate BM25 ranking
score using the corpus given.
Parameters
----------
corpus : list[list[str]]
Each element in the list represents a document, and each document
is a list of the terms.
Returns
-------
self
"""
tf = []
df = {}
idf = {}
doc_len = []
corpus_size = 0
for document in corpus:
corpus_size += 1
doc_len.append(len(document))
# compute tf (term frequency) per document
frequencies = {}
for term in document:
term_count = frequencies.get(term, 0) + 1
frequencies[term] = term_count
tf.append(frequencies)
# compute df (document frequency) per term
for term, _ in frequencies.items():
df_count = df.get(term, 0) + 1
df[term] = df_count
        # smoothed ("Lucene-style") idf: stays positive even for terms that
        # appear in almost every document
        for term, freq in df.items():
            idf[term] = math.log(1 + (corpus_size - freq + 0.5) / (freq + 0.5))
self.tf_ = tf
self.df_ = df
self.idf_ = idf
self.doc_len_ = doc_len
self.corpus_ = corpus
self.corpus_size_ = corpus_size
self.avg_doc_len_ = sum(doc_len) / corpus_size
return self
    def search(self, query):
        """Return one BM25 score per corpus document for a tokenized query."""
        scores = [self._score(query, index) for index in range(self.corpus_size_)]
        return scores

    def _score(self, query, index):
        # score(q, d) = sum over terms t in q of
        #   idf(t) * tf(t, d) * (k1 + 1) / (tf(t, d) + k1 * (1 - b + b * |d| / avgdl))
        score = 0.0
        doc_len = self.doc_len_[index]
        frequencies = self.tf_[index]
        for term in query:
            if term not in frequencies:
                continue
            freq = frequencies[term]
            numerator = self.idf_[term] * freq * (self.k1 + 1)
            denominator = freq + self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_len_)
            score += (numerator / denominator)
        return score
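

# Minimal usage sketch (illustrative, not part of the original file): BM25
# works on pre-tokenized text, so both the corpus and the query are lists of
# terms, and search() returns one score per document.
#
#   bm25 = BM25(k1=1.5, b=0.75)
#   bm25.fit([['viêm', 'lợi'], ['đau', 'đầu', 'kéo_dài']])
#   scores = bm25.search(['viêm', 'lợi'])  # higher score = better match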


class Retrieval:
    def __init__(
        self, k=8,
        model='retrieval/bm25.pt',
        contexts='retrieval/context.pt',
        stop_words='retrieval/stopwords.csv',
        max_len=400,
        docs=None
    ) -> None:
        self.k = k                # number of contexts returned per query
        self.max_len = max_len    # maximum chunk length, in words
        data = pd.read_csv(stop_words, sep='\t', encoding='utf-8')
        # store stop words as a set: membership tests are O(1), and `in` on a
        # pandas Series would check the index rather than the values
        self.list_stopwords = set(data['stopwords'])
        if docs:
            # build the BM25 index from the raw document(s) passed in
            self.tuning(docs)
        else:
            # otherwise load a pre-fitted BM25 model and its chunked contexts
            self.bm25 = torch.load(model)
            self.contexts = torch.load(contexts)
    def get_context(self, query='Chảy máu chân răng là bệnh gì?'):
        def clean_text(text):
            # strip HTML tags and collapse runs of whitespace
            text = re.sub(r'<.*?>', '', text).strip()
            text = re.sub(r'(\s)+', r'\1', text)
            return text

        def normalize_text(text):
            # replace punctuation with spaces, keeping '_' because pyvi joins
            # the syllables of a compound word with underscores
            listpunctuation = string.punctuation.replace('_', '')
            for i in listpunctuation:
                text = text.replace(i, ' ')
            return text.lower()

        def remove_stopword(text):
            pre_text = []
            words = text.split()
            for word in words:
                if word not in self.list_stopwords:
                    pre_text.append(word)
            return ' '.join(pre_text)

        def word_segment(sent):
            # Vietnamese word segmentation via pyvi
            return tokenize(sent.encode('utf-8').decode('utf-8'))

        # preprocess the query exactly like the indexed contexts
        query = clean_text(query)
        query = word_segment(query)
        query = remove_stopword(normalize_text(query))
        query = query.split()
        scores = self.bm25.search(query)
        scores_index = np.argsort(scores)  # ascending: best documents sit at the end
        results = []
        for k in range(1, self.k + 1):
            index = scores_index[-k]       # k-th highest score
            result = {'score': scores[index], 'index': index, 'context': self.contexts[index]}
            results.append(result)
        return results
    def split(self, document):
        """Chunk a document into ~max_len-word contexts with a one-sentence overlap."""
        document = document.replace('\n', ' ')
        document = re.sub(' +', ' ', document)
        sentences = document.split('. ')
        context_list = []
        context = ""
        length = 0
        pre = ""    # previous sentence, carried over as the overlap
        len__ = 0   # word count of the previous sentence
        for sentence in sentences:
            sentence += '. '
            len_ = len(sentence.split())
            if length + len_ > self.max_len:
                # close the current chunk and seed the next one with the
                # previous sentence, so adjacent chunks share one sentence
                context_list.append(context)
                context = pre
                length = len__
            length += len_
            context += sentence
            pre = sentence
            len__ = len_
        context_list.append(context)
        self.contexts = context_list
        # never return more contexts than exist
        if len(context_list) < self.k:
            self.k = len(context_list)
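
    # Illustrative trace (not part of the original file): with max_len=4,
    # splitting 'a b c. d e f. g h i' yields
    #   ['a b c. ', 'a b c. d e f. ', 'd e f. g h i. ']
    # Each chunk repeats the previous chunk's last sentence, and a chunk can
    # exceed max_len because the limit is checked before a sentence is added.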
    def tuning(self, document):
        # these helpers mirror the ones in get_context, so indexed contexts
        # and incoming queries share a single preprocessing pipeline
        def clean_text(text):
            text = re.sub(r'<.*?>', '', text).strip()
            text = re.sub(r'(\s)+', r'\1', text)
            return text

        def normalize_text(text):
            listpunctuation = string.punctuation.replace('_', '')
            for i in listpunctuation:
                text = text.replace(i, ' ')
            return text.lower()

        def remove_stopword(text):
            pre_text = []
            words = text.split()
            for word in words:
                if word not in self.list_stopwords:
                    pre_text.append(word)
            return ' '.join(pre_text)

        def word_segment(sent):
            return tokenize(sent.encode('utf-8').decode('utf-8'))
        self.split(document)
        docs = []
        for content in self.contexts:
            content = clean_text(content)
            content = word_segment(content)
            content = remove_stopword(normalize_text(content))
            docs.append(content)
        print('There are', len(docs), 'contexts')
        texts = [
            [word for word in document.lower().split() if word not in self.list_stopwords]
            for document in docs
        ]
        self.bm25 = BM25()
        self.bm25.fit(texts)
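

if __name__ == '__main__':
    # Hedged usage sketch (not part of the original file): building the index
    # from a raw document via `docs=` instead of loading the pickled
    # 'retrieval/bm25.pt' / 'retrieval/context.pt' artifacts. The sample text
    # and question are placeholders; 'retrieval/stopwords.csv' must exist.
    sample_doc = (
        'Chảy máu chân răng thường do viêm lợi hoặc viêm nha chu. '
        'Người bệnh nên vệ sinh răng miệng đúng cách và khám nha khoa sớm.'
    )
    retriever = Retrieval(k=2, docs=sample_doc)
    for hit in retriever.get_context('Chảy máu chân răng là bệnh gì?'):
        print(hit['score'], hit['index'], hit['context'])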