Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import gradio as gr | |
from typing import TypedDict | |
from dataclasses import dataclass | |
import pickle | |
import os | |
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar | |
from collections import Counter | |
import re | |
import nltk | |
nltk.download("stopwords", quiet=True) | |
from nltk.corpus import stopwords as nltk_stopwords | |
from dataclasses import asdict, dataclass | |
import math | |
import os | |
from typing import Iterable, List, Optional, Type | |
import tqdm | |
from nlp4web_codebase.nlp4web_codebase.ir.data_loaders.dm import Document | |
from nlp4web_codebase.nlp4web_codebase.ir.models import BaseRetriever | |
from nlp4web_codebase.nlp4web_codebase.ir.data_loaders.sciq import load_sciq | |
from typing import Type | |
from abc import abstractmethod | |
from nlp4web_codebase.nlp4web_codebase.ir.data_loaders import Split | |
import pytrec_eval | |
import numpy as np | |
from scipy.sparse._csc import csc_matrix | |
# Language whose NLTK stopword list is used for filtering.
LANGUAGE = "english"

# A token is a run of at least two word characters between word boundaries,
# so single-character tokens are dropped.
_WORD_PATTERN = re.compile(r"(?u)\b\w\w+\b")
word_splitter = _WORD_PATTERN.findall

# Kept as a set so the per-token membership test is a hash lookup.
stopwords = set(nltk_stopwords.words(LANGUAGE))
def word_splitting(text: str) -> List[str]:
    """Lower-case ``text`` and split it into tokens with ``word_splitter``."""
    lowered = text.lower()
    return word_splitter(lowered)
def lemmatization(words: List[str]) -> List[str]:
    """Return ``words`` unchanged.

    Lemmatization is deliberately skipped here for simplicity; the function
    is an identity step in the tokenization pipeline.
    """
    return words
def simple_tokenize(text: str) -> List[str]:
    """Tokenize ``text``: split into lower-cased words, drop stopwords, lemmatize."""
    kept = [word for word in word_splitting(text) if word not in stopwords]
    return lemmatization(kept)
# Type variable standing for "InvertedIndex or one of its subclasses"; it lets
# a classmethod be annotated as returning the class it was called on.
T = TypeVar("T", bound="InvertedIndex")
@dataclass
class PostingList:
    """Postings of a single term, stored as two parallel lists.

    ``docid_postings[i]`` and ``tweight_postings[i]`` together describe the
    i-th posting of ``term``. The ``@dataclass`` decorator provides the
    keyword constructor that the rest of the file relies on.
    """

    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] is the docid (int) of the i-th posting
    tweight_postings: List[float]  # tweight_postings[i] is the term weight (float) of the i-th posting
@dataclass
class InvertedIndex:
    """In-memory inverted index that can be pickled to / restored from disk."""

    posting_lists: List[PostingList]  # tid -> posting list (indexed by the ids stored in `vocab`)
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load the index previously written by ``save`` from ``<saved_dir>/index.pkl``.

        The object is returned exactly as unpickled; no placeholder instance
        is built first, and its type is not checked against ``cls``.

        Raises:
            FileNotFoundError: if ``index.pkl`` does not exist in ``saved_dir``.
        """
        # SECURITY NOTE: unpickling can execute arbitrary code. Only load
        # index files that this program (or another trusted party) produced.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
# The output of the counting function:
@dataclass
class Counting:
    """Raw corpus statistics produced by ``run_counting``."""

    posting_lists: List[PostingList]  # tid -> posting list (weights are raw term frequencies)
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float  # average document length
    nterms: int  # total number of counted tokens
    doc_texts: Optional[List[str]] = None  # docid -> document text
def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc.

    Args:
        documents: Documents to index; each must expose ``collection_id`` and
            ``text``. A document whose ``collection_id`` was already seen is
            skipped.
        tokenize_fn: Function turning a document text into a list of tokens.
        store_raw: Whether to keep the document texts in ``doc_texts``
            (``None`` is returned there otherwise).
        ndocs: Expected number of documents; only used as the progress-bar total.
        show_progress_bar: Whether to display the progress bar.

    Returns:
        A ``Counting`` holding posting lists with raw term frequencies as
        weights, the vocabulary, id mappings, document frequencies, document
        lengths, the average document length (``0.0`` for an empty corpus)
        and the total token count.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None

    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue  # duplicate document: keep the first occurrence only

        # docids are assigned densely in order of first appearance.
        docid = len(collection_ids)
        cid2docid[doc.collection_id] = docid
        collection_ids.append(doc.collection_id)

        tok2tf = Counter(tokenize_fn(doc.text))
        doc_length = sum(tok2tf.values())
        dls.append(doc_length)
        nterms += doc_length

        for tok, tf in tok2tf.items():
            tid = vocab.get(tok)
            if tid is None:
                # First time this term is seen: allocate its id, posting
                # list and document-frequency slot together.
                tid = len(vocab)
                vocab[tok] = tid
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # Each (term, document) pair is visited once, so df always equals
            # the posting-list length. This also counts the first document a
            # term appears in, which was previously recorded as df = 0.
            dfs[tid] += 1

        if doc_texts is not None:
            doc_texts.append(doc.text)

    # Guard against an empty corpus instead of dividing by zero.
    avgdl = sum(dls) / len(dls) if dls else 0.0
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=avgdl,
        nterms=nterms,
        doc_texts=doc_texts,
    )
# Load the SciQ dataset and gather the raw corpus statistics.
sciq = load_sciq()
counting = run_counting(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
)
"""### BM25 Index""" | |
@dataclass
class BM25Index(InvertedIndex):
    """Inverted index whose posting weights are precomputed BM25 term weights.

    After ``build_from_documents`` every entry of ``tweight_postings`` holds
    ``regularized_tf * idf`` for its (term, document) pair, so scoring a
    query only requires summing cached weights.
    """

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize`` (used for documents and queries)."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute term weights and caching

        Replaces, in place, every raw term frequency stored in
        ``posting_lists[tid].tweight_postings`` by its BM25 weight.

        Args:
            posting_lists: tid -> posting list holding raw term frequencies.
            total_docs: Number of documents in the collection (N in the IDF).
            avgdl: Average document length.
            dfs: tid -> document frequency.
            dls: docid -> document length.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.
        """
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            weights = posting_list.tweight_postings
            for i, (docid, tf) in enumerate(
                zip(posting_list.docid_postings, weights)
            ):
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                # Overwrite the slot just read; the list length is unchanged.
                weights[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return ``tf / (tf + k1 * (1 - b + b * dl / avgdl))``."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        """Return ``log(1 + (N - df + 0.5) / (df + 0.5))``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Count corpus statistics, cache BM25 weights and assemble the index.

        Args:
            documents: Documents to index.
            store_raw: Whether to keep the document texts in the index.
            output_dir: Accepted for interface compatibility; it is not used
                here, so callers must call ``save`` themselves.
            ndocs: Expected number of documents (progress-bar total only).
            show_progress_bar: Whether to show the counting progress bar.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            An instance of ``cls`` with cached BM25 term weights.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        cls.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly:
        return cls(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
# Index the SciQ corpus with the default BM25 parameters and persist it so
# that a retriever can load it from disk.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True
)
bm25_index.save("output/bm25_index")
"""### BM25 Retriever""" | |
class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents by summing cached posting weights.

    Subclasses only have to provide ``index_class``; the index itself is
    loaded from disk in ``__init__``.
    """

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        """Index class whose ``from_saved`` is used to load the index."""

    def __init__(self, index_dir: str) -> None:
        """Load the index stored under ``index_dir``."""
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return, per query term, its cached weight in document ``cid``.

        Query terms missing from the vocabulary or from the document are
        omitted. Raises ``KeyError`` if ``cid`` is not in the index.
        """
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights: Dict[str, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists[tid]
            # Linear scan of the term's postings for the target document.
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the query-term weights in document ``cid``."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` best documents as ``{collection_id: score}``.

        The result is ordered by descending score. A term that occurs several
        times in the query contributes once per occurrence.
        """
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score[docid] = docid2score.get(docid, 0) + tweight

        top_hits = sorted(
            docid2score.items(), key=lambda pair: pair[1], reverse=True
        )[:topk]
        return {self.index.collection_ids[docid]: score for docid, score in top_hits}
class BM25Retriever(BaseInvertedIndexRetriever):
    """Retriever backed by a saved ``BM25Index``."""

    @property
    def index_class(self) -> Type[BM25Index]:
        """Index class used to load the saved index."""
        return BM25Index
# Load the saved index and run one example query; the returned ranking is
# not stored.
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
bm25_retriever.retrieve(
    "What type of diseases occur when the immune system attacks normal body cells?"
)
def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
    """Return the mean ``map_cut_10`` of ``rankings`` on the given SciQ split.

    Args:
        rankings: query_id -> {collection_id: score}.
        split: Dataset split whose relevance judgments are used.

    Returns:
        The per-query ``map_cut_10`` values reported by pytrec_eval, averaged.
    """
    metric = "map_cut_10"
    # Fetch the relevance judgments once and reuse them for the evaluator.
    qrels = sciq.get_qrels_dict(split)
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
    per_query = evaluator.evaluate(rankings)
    return float(np.mean([scores[metric] for scores in per_query.values()]))
"""Example of using the pre-requisite code:""" | |
# Loading dataset:
sciq = load_sciq()
counting = run_counting(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
)

# Building BM25 index and save:
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True
)
bm25_index.save("output/bm25_index")

# Loading index and use BM25 retriever to retrieve:
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
print(
    bm25_retriever.retrieve(
        "What type of diseases occur when the immune system attacks normal body cells?"
    )
)  # the ranking
plots_b: Dict[str, List[float]] = {
    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    "Y": [],
}
plots_k1: Dict[str, List[float]] = {
    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    "Y": [],
}

## YOUR_CODE_STARTS_HERE
# Two steps should be involved:
# Step 1. Fix k1 value to the default one 0.9,
# go through all the candidate b values (0, 0.1, ..., 1.0),
# and record in plots_b["Y"] the corresponding performances obtained via evaluate_map;
# Step 2. Fix b to the best one in step 1. and do the same for k1.
# Hint (on using the pre-requisite code):
# - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
# - One can build bm25_index with `BM25Index.build_from_documents`;
# - One can use BM25Retriever to load the index and perform retrieval on the dev queries
# (dev queries can be obtained via sciq.get_split_queries(Split.dev))


def _dev_map_for(k1: float, b: float) -> float:
    """Build a BM25 index with ``(k1, b)`` and return its score on the dev split.

    The index is saved to and reloaded from ``output/bm25_index`` (overwriting
    whatever was stored there), the dev queries are retrieved and the result
    of ``evaluate_map`` is returned.
    """
    index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=len(sciq.corpus),  # progress-bar total only
        show_progress_bar=True,
        k1=k1,
        b=b,
    )
    index.save("output/bm25_index")
    retriever = BM25Retriever(index_dir="output/bm25_index")
    dev_rankings = {
        query.query_id: retriever.retrieve(query=query.text)
        for query in sciq.get_split_queries(Split.dev)
    }
    return evaluate_map(dev_rankings, split=Split.dev)


# Step 1: sweep b with k1 fixed to its default.
k1 = 0.9
b_list = []
for b in plots_b["X"]:
    b_list.append(_dev_map_for(k1=k1, b=b))
plots_b["Y"] = b_list

# Step 2: sweep k1 with b fixed to the best value found in step 1.
b = plots_b["X"][np.argmax(plots_b["Y"])]
k1_list = []
for k1 in plots_k1["X"]:
    k1_list.append(_dev_map_for(k1=k1, b=b))
plots_k1["Y"] = k1_list
"""Let's check the effectiveness gain on test after this tuning on dev""" | |
# Reference score to compare the tuned parameters against.
default_map = 0.7849

# Pick the parameter values with the highest recorded dev score.
best_b = plots_b["X"][np.argmax(plots_b["Y"])]
best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]

# Rebuild, save and reload the index with the tuned parameters.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
bm25_index.save("output/bm25_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")

rankings = {}
for query in sciq.get_split_queries(Split.test):  # note this is now on test
    ranking = bm25_retriever.retrieve(query=query.text)
    rankings[query.query_id] = ranking
optimized_map = evaluate_map(rankings, split=Split.test)  # note this is now on test
print(default_map, optimized_map)
"""## TASK2.2: implement `CSCBM25Index` (4 points) | |
Implement `CSCBM25Index` by completing the missing code. Note that `CSCInvertedIndex` is similar to the `InvertedIndex` that we discussed in class; the main difference is that the posting lists are represented by a CSC sparse matrix.
""" | |
# Type variable for "CSCInvertedIndex or one of its subclasses".
_CSCIndexT = TypeVar("_CSCIndexT", bound="CSCInvertedIndex")


@dataclass
class CSCInvertedIndex:
    """Inverted index whose posting lists are stored as one CSC sparse matrix."""

    # Rows are docids and columns are term ids, so column `tid` is the
    # posting list of that term.
    posting_lists_matrix: csc_matrix
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[_CSCIndexT], saved_dir: str) -> _CSCIndexT:
        """Load the index previously written by ``save`` from ``<saved_dir>/index.pkl``.

        The object is returned exactly as unpickled; no placeholder instance
        is built first, and its type is not checked against ``cls``.

        Raises:
            FileNotFoundError: if ``index.pkl`` does not exist in ``saved_dir``.
        """
        # SECURITY NOTE: unpickling can execute arbitrary code. Only load
        # index files that this program (or another trusted party) produced.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """BM25 index whose cached term weights live in a CSC sparse matrix.

    Entry ``[docid, tid]`` of ``posting_lists_matrix`` holds
    ``regularized_tf * idf`` for that (document, term) pair.
    """

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize`` (used for documents and queries)."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute term weights and caching

        Args:
            posting_lists: tid -> posting list holding raw term frequencies.
            total_docs: Number of documents in the collection (N in the IDF).
            avgdl: Average document length.
            dfs: tid -> document frequency.
            dls: docid -> document length.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            A float32 CSC matrix with one row per document and one column per
            term; column ``tid`` stores the BM25 weights of that term's
            postings. ``posting_lists`` itself is not modified.
        """
        ## YOUR_CODE_STARTS_HERE
        data: List[float] = []  # BM25 weights, column after column
        indices: List[int] = []  # docid (row) of each weight
        indptr: List[int] = [0]  # column tid spans data[indptr[tid]:indptr[tid + 1]]
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=N)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                indices.append(docid)
                data.append(regularized_tf * idf)
            indptr.append(len(indices))

        # Fix the shape explicitly: when it is inferred from the data, trailing
        # documents without any posting are missing from the matrix and
        # looking them up raises IndexError.
        n_rows = max(total_docs, len(dls))
        posting_lists_matrix = csc_matrix(
            (data, indices, indptr), shape=(n_rows, len(posting_lists))
        ).astype(np.float32)
        return posting_lists_matrix
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return ``tf / (tf + k1 * (1 - b + b * dl / avgdl))``."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        """Return ``log(1 + (N - df + 0.5) / (df + 0.5))``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count corpus statistics, build the weight matrix and assemble the index.

        Args:
            documents: Documents to index.
            store_raw: Whether to keep the document texts in the index.
            output_dir: Accepted for interface compatibility; it is not used
                here, so callers must call ``save`` themselves.
            ndocs: Expected number of documents (progress-bar total only).
            show_progress_bar: Whether to show the counting progress bar.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            An instance of ``cls`` holding the BM25 weight matrix.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists_matrix = cls.cache_term_weights(
            posting_lists=counting.posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly:
        return cls(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
# Build the CSC-backed BM25 index with the tuned parameters and persist it.
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents from a CSC matrix of cached term weights.

    Subclasses only have to provide ``index_class``; the index itself is
    loaded from disk in ``__init__``.
    """

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        """Index class whose ``from_saved`` is used to load the index."""

    def __init__(self, index_dir: str) -> None:
        """Load the index stored under ``index_dir``."""
        self.index = self.index_class.from_saved(index_dir)

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize ``text`` with the tokenizer of the loaded index's class."""
        # Looked up on the class so that this base retriever is not tied to
        # one concrete index implementation.
        return type(self.index).tokenize(text)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return, per query term, its cached weight in document ``cid``.

        Query terms missing from the vocabulary, or whose weight in the
        document is zero, are omitted. Raises ``KeyError`` if ``cid`` is not
        in the index.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self._tokenize(query)
        target_docid = self.index.cid2docid[cid]
        matrix = self.index.posting_lists_matrix
        term_weights: Dict[str, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            weight = matrix[target_docid, tid]
            if weight == 0:
                continue
            # Convert the numpy scalar to a builtin float, as annotated.
            term_weights[tok] = float(weight)
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the query-term weights in document ``cid``."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` best documents as ``{collection_id: score}``.

        The result is ordered by descending score and the scores are builtin
        floats. A term that occurs several times in the query contributes
        once per occurrence.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self._tokenize(query)
        matrix = self.index.posting_lists_matrix
        docid2score: Dict[int, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            # Column `tid` of a CSC matrix is the slice
            # indptr[tid]:indptr[tid + 1] of `indices` (docids) and `data`
            # (weights); slicing avoids building a temporary column matrix.
            start, end = matrix.indptr[tid], matrix.indptr[tid + 1]
            for docid, tweight in zip(
                matrix.indices[start:end], matrix.data[start:end]
            ):
                docid2score[docid] = docid2score.get(docid, 0) + tweight

        top_hits = sorted(
            docid2score.items(), key=lambda pair: pair[1], reverse=True
        )[:topk]
        ranking = {
            self.index.collection_ids[docid]: float(score)
            for docid, score in top_hits
        }
        return ranking
        ## YOUR_CODE_ENDS_HERE
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Retriever backed by a saved ``CSCBM25Index``."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        """Index class used to load the saved index."""
        return CSCBM25Index
class Hit(TypedDict):
    """One search result returned by the demo's search function."""

    cid: str  # collection id of the document
    score: float  # retrieval score of the document
    text: str  # document text
demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# Build the index with the tuned parameters and persist it for the demo.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
bm25_index.save("output/bm25_index")

# Load the retriever once at start-up; unpickling the whole index inside
# `search` would repeat that work for every single query.
_demo_retriever = BM25Retriever(index_dir="output/bm25_index")


def search(query: str) -> List[Hit]:
    """Retrieve the top documents for ``query`` and return them as hits.

    Each hit carries the collection id, the retrieval score and the stored
    document text, in the order returned by the retriever.
    """
    index = _demo_retriever.index
    hits: return_type = []
    for cid, score in _demo_retriever.retrieve(query).items():
        docid = index.cid2docid[cid]
        hits.append(Hit(cid=cid, score=score, text=index.doc_texts[docid]))
    return hits


demo = gr.Interface(
    fn=search,
    inputs=["text"],
    outputs=["text"],
)
## YOUR_CODE_ENDS_HERE
demo.launch()