import csv
from typing import Any

import gradio as gr
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from underthesea import word_tokenize

from retriever_trainer import PretrainedColBERT

# Number of bi-encoder hits kept as candidates (and handed to the reranker).
NUM_CANDIDATES = 20
# Number of parent passages returned; must match the number of output boxes.
NUM_RESULTS = 5

bi_encoder = SentenceTransformer("phamson02/cotmae_biencoder2_170000_sbert")
colbert = PretrainedColBERT(
    pretrained_model_name="phamson02/colbert2.1_290000",
)

# NOTE(security): unpickling executes arbitrary code; only load this file
# from a trusted source.
corpus_embeddings = pd.read_pickle("data/passage_embeds.pkl")

# Explicit UTF-8 keeps decoding independent of the host locale, and
# newline="" is what the csv module requires for correct newline handling.
with open("data/child_passages.tsv", "r", encoding="utf-8", newline="") as f:
    tsv_reader = csv.reader(f, delimiter="\t")
    child_passage_ids, child_passages = zip(
        *[(row[0], row[1]) for row in tsv_reader]
    )

with open("data/parent_passages.tsv", "r", encoding="utf-8", newline="") as f:
    tsv_reader = csv.reader(f, delimiter="\t")
    parent_passages_map = {row[0]: row[1] for row in tsv_reader}


def f7(seq):
    """Return the items of *seq* as a list with duplicates removed.

    The first occurrence of each item is kept and the original order is
    preserved. Items must be hashable.
    """
    # dict preserves insertion order, so this is an order-preserving dedupe.
    return list(dict.fromkeys(seq))


def search(query: str, reranking: bool = False, top_k: int = 100):
    """Return the texts of the top parent passages for *query*.

    The query is word-segmented, embedded with the bi-encoder and matched
    against ``corpus_embeddings``. The best ``NUM_CANDIDATES`` child passages
    are optionally reordered by the ColBERT reranker, then mapped to their
    parent passages (child id minus its last ``_``-separated component) and
    de-duplicated in rank order.

    Args:
        query: Raw query text.
        reranking: When true, reorder the candidates with ``colbert.rerank``.
        top_k: Number of hits requested from the semantic search; only the
            first ``NUM_CANDIDATES`` of them are used.

    Returns:
        A tuple of exactly ``NUM_RESULTS`` strings. When fewer distinct
        parent passages are found, the remaining slots are empty strings.

    Raises:
        KeyError: If a derived parent id is missing from
            ``parent_passages_map``.
    """
    query = word_tokenize(query, format="text")

    print("Top 5 Answer by the NSE:")
    print()

    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant
    # passages. semantic_search returns one hit list per query; we sent one.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(
        question_embedding, corpus_embeddings, top_k=top_k
    )[0]
    candidate_indices = [hit["corpus_id"] for hit in hits[:NUM_CANDIDATES]]

    ##### Re-Ranking #####
    if reranking:
        candidate_texts = [child_passages[index] for index in candidate_indices]
        colbert_scores: list[dict[str, Any]] = colbert.rerank(
            query=query, documents=candidate_texts, top_k=NUM_CANDIDATES
        )
        # Each score's "corpus_id" is used as a position in the candidate
        # list passed to the reranker, so map it back to a corpus index.
        candidate_indices = [
            candidate_indices[score["corpus_id"]] for score in colbert_scores
        ]

    candidate_indices = candidate_indices[:NUM_CANDIDATES]

    # A parent id is the child id without its final "_"-separated component;
    # rpartition yields "" when the id has no underscore.
    parent_ids = f7(
        child_passage_ids[index].rpartition("_")[0]
        for index in candidate_indices
    )

    results = [
        parent_passages_map[parent_id] for parent_id in parent_ids[:NUM_RESULTS]
    ]
    # Pad instead of asserting: the UI always expects NUM_RESULTS outputs and
    # an assert would be stripped under ``python -O``.
    results.extend([""] * (NUM_RESULTS - len(results)))

    return tuple(results)


exp = [
    ["Who is steve jobs?", False],
    ["What is coldplay?", False],
    ["What is a turing test?", False],
    ["What is the most interesting thing about our universe?", False],
    ["What are the most beautiful places on earth?", False],
]

desc = "This is a semantic search engine powered by SentenceTransformers (Nils_Reimers) with a retrieval and reranking system on Wikipedia corous. This will return the top 5 results. So Quest on with Transformers."

inp = gr.Textbox(lines=1, placeholder=None, label="search you query here")
reranking_checkbox = gr.Checkbox(label="Enable reranking")
out1 = gr.Textbox(type="text", label="Search result 1")
out2 = gr.Textbox(type="text", label="Search result 2")
out3 = gr.Textbox(type="text", label="Search result 3")
out4 = gr.Textbox(type="text", label="Search result 4")
out5 = gr.Textbox(type="text", label="Search result 5")

iface = gr.Interface(
    fn=search,
    inputs=[inp, reranking_checkbox],
    outputs=[out1, out2, out3, out4, out5],
    examples=exp,
    article=desc,
    title="Neural Search Engine",
)

# Launched at import time, as in the original script.
iface.launch()