import math
import re
import string

import numpy as np
import pandas as pd
import torch
from pyvi.ViTokenizer import tokenize
class BM25:
    """Okapi BM25 ranking over a pre-tokenized corpus."""

    def __init__(self, k1=1.5, b=0.75):
        self.b = b    # document-length normalization (0 = none, 1 = full)
        self.k1 = k1  # term-frequency saturation parameter

    def fit(self, corpus):
        """
        Fit the statistics required to calculate the BM25 ranking
        score from the given corpus.

        Parameters
        ----------
        corpus : list[list[str]]
            Each element of the list represents a document, and each
            document is a list of terms.

        Returns
        -------
        self
        """
        tf = []
        df = {}
        idf = {}
        doc_len = []
        corpus_size = 0
        for document in corpus:
            corpus_size += 1
            doc_len.append(len(document))

            # compute tf (term frequency) per document
            frequencies = {}
            for term in document:
                frequencies[term] = frequencies.get(term, 0) + 1
            tf.append(frequencies)

            # compute df (document frequency) per term
            for term in frequencies:
                df[term] = df.get(term, 0) + 1

        for term, freq in df.items():
            idf[term] = math.log(1 + (corpus_size - freq + 0.5) / (freq + 0.5))

        self.tf_ = tf
        self.df_ = df
        self.idf_ = idf
        self.doc_len_ = doc_len
        self.corpus_ = corpus
        self.corpus_size_ = corpus_size
        self.avg_doc_len_ = sum(doc_len) / corpus_size
        return self
    def search(self, query):
        """Return a BM25 score for every document against a tokenized query."""
        return [self._score(query, index) for index in range(self.corpus_size_)]

    def _score(self, query, index):
        # BM25 score of document `index`, summed over query terms t:
        #   idf(t) * tf(t, d) * (k1 + 1) / (tf(t, d) + k1 * (1 - b + b * |d| / avgdl))
        score = 0.0
        doc_len = self.doc_len_[index]
        frequencies = self.tf_[index]
        for term in query:
            if term not in frequencies:
                continue
            freq = frequencies[term]
            numerator = self.idf_[term] * freq * (self.k1 + 1)
            denominator = freq + self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_len_)
            score += numerator / denominator
        return score
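
# A minimal usage sketch of BM25 on its own; the toy corpus and query
# below are illustrative assumptions, not part of the original code:
#
#     corpus = [
#         ['chảy', 'máu', 'chân', 'răng'],
#         ['đau', 'đầu', 'kéo', 'dài'],
#     ]
#     bm25 = BM25().fit(corpus)
#     scores = bm25.search(['chảy', 'máu'])  # one score per document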
class Retrieval:
    def __init__(
        self, k=8,
        model='retrieval/bm25.pt',
        contexts='retrieval/context.pt',
        stop_words='retrieval/stopwords.csv',
        max_len=400,
        docs=None,
    ) -> None:
        self.k = k
        self.max_len = max_len
        data = pd.read_csv(stop_words, sep='\t', encoding='utf-8')
        # store stopwords in a set for value-based membership tests
        # (`in` on a pandas Series checks the index, not the values)
        self.list_stopwords = set(data['stopwords'])
        if docs:
            # build a fresh index from the raw document
            self.tuning(docs)
        else:
            # otherwise load a previously fitted index and its contexts
            self.bm25 = torch.load(model)
            self.contexts = torch.load(contexts)
    def _preprocess(self, text):
        """Clean, word-segment, normalize, and strip stopwords from raw text."""
        text = re.sub(r'<.*?>', '', text).strip()  # drop HTML tags
        text = re.sub(r'(\s)+', r'\1', text)       # collapse repeated whitespace
        text = tokenize(text)                      # Vietnamese word segmentation (pyvi)
        # replace punctuation with spaces, keeping '_' because pyvi joins
        # compound words with underscores
        for punct in string.punctuation.replace('_', ''):
            text = text.replace(punct, ' ')
        text = text.lower()
        words = [word for word in text.split() if word not in self.list_stopwords]
        return ' '.join(words)

    def get_context(self, query='Chảy máu chân răng là bệnh gì?'):
        query = self._preprocess(query).split()
        scores = self.bm25.search(query)
        scores_index = np.argsort(scores)
        results = []
        # collect the k highest-scoring contexts, best first
        for k in range(1, self.k + 1):
            index = scores_index[-k]
            results.append({
                'score': scores[index],
                'index': index,
                'context': self.contexts[index],
            })
        return results
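
    # get_context returns the contexts ranked by BM25 score, e.g.
    # (scores and indices below are illustrative, not real output):
    #
    #     [{'score': 7.13, 'index': 2, 'context': 'Chảy máu chân răng ...'},
    #      {'score': 4.02, 'index': 0, 'context': '...'},
    #      ...]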
    def split(self, document):
        """Split a document into chunks of at most max_len words, overlapping
        consecutive chunks by one sentence for continuity."""
        document = document.replace('\n', ' ')
        document = re.sub(' +', ' ', document)
        sentences = document.split('. ')
        context_list = []
        context = ''
        length = 0
        pre = ''    # previous sentence, carried over as the overlap
        len__ = 0   # word count of that previous sentence
        for sentence in sentences:
            sentence += '. '
            len_ = len(sentence.split())
            if length + len_ > self.max_len:
                # the current chunk is full: emit it, then start the next
                # chunk from the previous sentence so chunks overlap
                context_list.append(context)
                context = pre
                length = len__
            length += len_
            context += sentence
            pre = sentence
            len__ = len_
        context_list.append(context)
        self.contexts = context_list
        # never ask for more contexts than exist
        if len(context_list) < self.k:
            self.k = len(context_list)
    def tuning(self, document):
        """Chunk a raw document and fit a fresh BM25 index over the chunks."""
        self.split(document)
        docs = [self._preprocess(content) for content in self.contexts]
        print('There are', len(docs), 'contexts')
        texts = [
            [word for word in doc.split() if word not in self.list_stopwords]
            for doc in docs
        ]
        self.bm25 = BM25()
        self.bm25.fit(texts)
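
# A minimal end-to-end sketch, assuming 'retrieval/stopwords.csv' exists;
# the sample document and query are illustrative, not from the original code:
if __name__ == '__main__':
    sample_doc = (
        'Chảy máu chân răng thường do viêm nướu gây ra. '
        'Vệ sinh răng miệng đúng cách giúp phòng ngừa viêm nướu. '
    )
    # fit a BM25 index over the chunked sample document, then query it
    retriever = Retrieval(k=2, docs=sample_doc)
    for hit in retriever.get_context('Chảy máu chân răng là bệnh gì?'):
        print(hit['score'], '-', hit['context'])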