Kurt committed on
Commit
530e7d9
1 Parent(s): 8ca8fb7
Files changed (1)
  1. app.py +348 -0
app.py ADDED
from __future__ import annotations

import gradio as gr
from typing import TypedDict
from dataclasses import dataclass
import math
import pickle
import os
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
from nlp4web_codebase.ir.data_loaders.dm import Document
from collections import Counter
import tqdm
import re
import nltk
nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords

LANGUAGE = "english"
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))

def word_splitting(text: str) -> List[str]:
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    return words  # We ignore lemmatization here for simplicity


def simple_tokenize(text: str) -> List[str]:
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized

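# For example (stopwords "what" and "of" are dropped; casing is normalized):
# simple_tokenize("What type of diseases occur?") -> ["type", "diseases", "occur"]
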
T = TypeVar("T", bound="InvertedIndex")


@dataclass
class PostingList:
    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
    tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # tid -> posting_list
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index

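# Illustration with made-up numbers: if vocab["cell"] == 7, then posting_lists[7]
# might be PostingList(term="cell", docid_postings=[2, 5], tweight_postings=[1.8, 0.9]),
# i.e. "cell" occurs in docs 2 and 5 with cached term weights 1.8 and 0.9.
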
# The output of the counting function:
@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None

def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                # A newly seen term already occurs in the current document,
                # so its document frequency starts at 1 (not 0).
                dfs.append(1)
        if store_raw:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )


from nlp4web_codebase.ir.data_loaders.sciq import load_sciq

sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))

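# At this point `counting` holds the raw statistics for the SciQ corpus
# (postings with raw TFs, DFs, document lengths, avgdl); nothing is
# BM25-weighted yet.
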
"""### BM25 Index"""

@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute BM25 term weights and cache them in the posting lists."""
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

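    # Putting the two helpers together, each cached posting weight is
    #   weight(t, d) = IDF(t) * tf / (tf + k1 * (1 - b + b * dl / avgdl)),
    #   IDF(t) = ln(1 + (N - df + 0.5) / (df + 0.5)),
    # so query-time scoring reduces to summing cached weights per document.
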
    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and cache them:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assemble the index:
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index

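# Note on the defaults: b controls how strongly scores are normalized by document
# length and k1 how quickly term frequency saturates; k1=0.9, b=0.4 are common
# BM25 defaults (e.g., Anserini ships the same values) rather than values tuned
# for SciQ specifically.
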
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")

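# The index on disk can be reloaded with BM25Index.from_saved("output/bm25_index");
# this is exactly what BM25Retriever.__init__ does below.
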
"""### BM25 Retriever"""

from nlp4web_codebase.ir.models import BaseRetriever
from abc import abstractmethod


class BaseInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }

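# Scoring is term-at-a-time: retrieve() walks the posting list of every query
# term, accumulates the cached BM25 weights per document, and keeps the topk.
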
class BM25Retriever(BaseInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index


bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")

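# The call above is a smoke test; it returns a dict of collection_id -> BM25
# score, best first, e.g. {"<cid>": 17.3, ...} (the score value is illustrative).
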
class Hit(TypedDict):
    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# The BM25 index has already been built and saved to "output/bm25_index" above,
# so it does not need to be rebuilt here.

def search(query: str) -> List[Hit]:
    # Reuse the module-level retriever rather than reloading the pickled index
    # on every query.
    result = bm25_retriever.retrieve(query)

    hits: List[Hit] = []
    for cid, score in result.items():
        docid = bm25_retriever.index.cid2docid[cid]
        text = bm25_retriever.index.doc_texts[docid]
        hits.append(Hit(cid=cid, score=score, text=text))

    return hits

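# Illustrative call: search("Why does water expand when it freezes?") returns up
# to 10 Hit dicts (cid, score, text), ordered by descending BM25 score.
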
demo = gr.Interface(
    fn=search,
    inputs=["text"],
    outputs=["text"],
)
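# With outputs=["text"], gradio renders the returned list of Hit dicts as a plain
# string; a gr.JSON() output component would display the hits as structured JSON
# instead (an optional presentation tweak, not part of the template).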
## YOUR_CODE_ENDS_HERE
demo.launch()