# Matsa-demo / semantic_retrieval.py
# Author: puneetm — "Upload folder using huggingface_hub" (commit 35d31f5, verified)
import numpy as np
from bs4 import BeautifulSoup
from sklearn.preprocessing import minmax_scale
from sentence_transformers import SentenceTransformer, util
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# Sentence-embedding model shared by every encoder call in this module.
# It is loaded once, when the module is imported.
sbert = SentenceTransformer("all-MiniLM-L6-v2")
from llm_query_api import *
def get_row_embedding(html_table):
    """Embed every ``<tr>`` of an (augmented) HTML table by its description.

    Each row is represented by the text ``" " + str(description)``, where
    ``description`` is the row's ``description`` attribute. A row without that
    attribute is therefore embedded as ``" None"``.

    Args:
        html_table: HTML markup containing the table.

    Returns:
        A tuple ``(embeddings, element_ids)``.
        ``embeddings`` is the NumPy array returned by ``sbert.encode`` for the
        row texts, in document order.
        ``element_ids`` lists the row ``id`` attributes, each passed through
        ``str``, so a row without an ``id`` yields ``'None'``.
    """

    def get_row_elements(html_table):
        # One {'id', 'text'} record per <tr>, in document order.
        soup = BeautifulSoup(html_table, 'html.parser')
        return [
            {'id': str(tr.get('id')), 'text': " " + str(tr.get('description'))}
            for tr in soup.find_all('tr')
        ]

    rows = get_row_elements(html_table)
    sentences = [row['text'] for row in rows]
    element_ids = [row['id'] for row in rows]
    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
def get_col_embedding(html_table):
    """Embed every ``<th>`` of an (augmented) HTML table by its description.

    Each header cell is represented by the text ``" " + str(description)``,
    where ``description`` is the cell's ``description`` attribute. A cell
    without that attribute is therefore embedded as ``" None"``. All ``<th>``
    elements in the markup are used, not only those in the first row.

    Args:
        html_table: HTML markup containing the table.

    Returns:
        A tuple ``(embeddings, element_ids)``.
        ``embeddings`` is the NumPy array returned by ``sbert.encode`` for the
        header texts, in document order.
        ``element_ids`` lists the ``id`` attributes of the ``<th>`` cells, each
        passed through ``str``, so a cell without an ``id`` yields ``'None'``.
    """

    def get_column_elements(html_table):
        # One {'id', 'text'} record per <th>, in document order.
        soup = BeautifulSoup(html_table, 'html.parser')
        return [
            {'id': str(th.get('id')), 'text': " " + str(th.get('description'))}
            for th in soup.find_all('th')
        ]

    cols = get_column_elements(html_table)
    sentences = [col['text'] for col in cols]
    element_ids = [col['id'] for col in cols]
    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
def normalize_list_numpy(list_numpy):
    """Linearly rescale ``list_numpy`` into the [0, 1] range.

    This is a thin wrapper around ``sklearn.preprocessing.minmax_scale`` with
    its default settings. The result is returned as produced by that call.
    """
    return minmax_scale(list_numpy)
def get_answer_embedding(answer):
    """Encode a single answer string with the shared ``sbert`` model.

    The text is sent to the encoder as a one-element batch. The resulting
    tensor is moved to the CPU and returned as a NumPy array.
    """
    batch = [answer]
    encoded = sbert.encode(batch, convert_to_tensor=True)
    return encoded.cpu().numpy()
def row_attribution(answer, html_table, topk=5, threshold=0.7):
    """Return the ids of the table rows most similar to ``answer``.

    Each row is scored by the cosine similarity between its description
    embedding and the answer embedding. The scores are min-max normalised.
    The best ``k`` rows are returned in descending score order, where
    ``k = min(max(topk, int(0.3 * n_rows)), n_rows)``.

    Args:
        answer: Text to attribute, e.g. one decomposed fact.
        html_table: HTML table whose ``<tr>`` elements carry ``description``
            (and ``id``) attributes.
        topk: Minimum number of rows to return, capped at the row count.
        threshold: Not used by this function; kept for API compatibility.

    Returns:
        A list of row ids (strings), highest similarity first. The list is
        empty when the table has no rows or ``k`` works out to zero.
    """
    row_embeddings, row_ids = get_row_embedding(html_table)
    if not row_ids:
        # No <tr> elements: nothing to rank. cosine_similarity would also
        # reject the empty embedding array.
        return []
    answer_embedding = get_answer_embedding(answer)
    similarities = cosine_similarity(row_embeddings, answer_embedding.reshape(1, -1))
    sims = normalize_list_numpy(similarities.flatten())
    # Take at least `topk` rows, or 30% of the rows if that is more,
    # but never more rows than the table has.
    k = min(max(topk, int(0.3 * len(sims))), len(sims))
    if k <= 0:
        # With k == 0, argpartition(...)[-0:] would select *every* row.
        return []
    top_k_indices = np.argpartition(sims, -k)[-k:]
    # argpartition leaves the top k unordered; order them by descending score.
    sorted_indices = top_k_indices[np.argsort(sims[top_k_indices])][::-1]
    return [row_ids[idx] for idx in sorted_indices]
def col_attribution(answer, html_table, topk=5, threshold=0.7):
    """Return the ids of the header columns most similar to ``answer``.

    Each ``<th>`` is scored by the cosine similarity between its description
    embedding and the answer embedding. The scores are min-max normalised.
    The best ``k`` columns are returned in descending score order, where
    ``k = min(max(topk, int(0.3 * n_cols)), n_cols)``.

    Args:
        answer: Text to attribute, e.g. one decomposed fact.
        html_table: HTML table whose ``<th>`` elements carry ``description``
            (and ``id``) attributes.
        topk: Minimum number of columns to return, capped at the column count.
        threshold: Not used by this function; kept for API compatibility.

    Returns:
        A list of column ids (strings), highest similarity first. The list is
        empty when the table has no ``<th>`` or ``k`` works out to zero.
    """
    col_embeddings, col_ids = get_col_embedding(html_table)
    if not col_ids:
        # No <th> elements: nothing to rank. cosine_similarity would also
        # reject the empty embedding array.
        return []
    answer_embedding = get_answer_embedding(answer)
    similarities = cosine_similarity(col_embeddings, answer_embedding.reshape(1, -1))
    sims = normalize_list_numpy(similarities.flatten())
    # Take at least `topk` columns, or 30% of the columns if that is more,
    # but never more columns than the table has.
    k = min(max(topk, int(0.3 * len(sims))), len(sims))
    if k <= 0:
        # With k == 0, argpartition(...)[-0:] would select *every* column.
        return []
    top_k_indices = np.argpartition(sims, -k)[-k:]
    # argpartition leaves the top k unordered; order them by descending score.
    sorted_indices = top_k_indices[np.argsort(sims[top_k_indices])][::-1]
    return [col_ids[idx] for idx in sorted_indices]
def retain_rows_and_columns(augmented_html_table, row_ids, column_ids):
    """Prune an HTML table down to the given rows and header columns.

    Args:
        augmented_html_table: HTML markup containing the table.
        row_ids: ``id`` values of the ``<tr>`` elements to keep. Every other
            ``<tr>``, including ones without an ``id``, is removed. Duplicate
            ids are allowed.
        column_ids: ``id`` values of the header ``<th>`` cells (in the first
            ``<tr>``) whose columns are kept. In every kept row, the cell at
            each other header-``<th>`` position is removed. Duplicate ids are
            allowed.

    Returns:
        The pruned markup as a string.
    """
    soup = BeautifulSoup(augmented_html_table, 'html.parser')
    keep_rows = set(row_ids)
    keep_cols = set(column_ids)
    all_rows = soup.find_all('tr')
    if not all_rows:
        return str(soup)

    # Decide which column positions to drop from the header (first) row
    # *before* removing any rows. The header row itself may be filtered out,
    # and a decomposed tag no longer has any children to search.
    # Positions are counted over <td> and <th> together, so they line up with
    # the per-row cell lists below; only <th> cells name columns.
    header_cells = all_rows[0].find_all(['td', 'th'])
    drop_positions = {
        pos
        for pos, cell in enumerate(header_cells)
        if cell.name == 'th' and cell.get('id') not in keep_cols
    }

    for row in all_rows:
        if row.get('id') not in keep_rows:
            row.decompose()
            continue
        # Snapshot the row's cells once. Removing a cell then cannot shift
        # the positions of the remaining ones, which would make later
        # deletions hit the wrong cell.
        cells = row.find_all(['td', 'th'])
        for pos in drop_positions:
            if pos < len(cells):
                cells[pos].decompose()
    return str(soup)
def get_embedding_attribution(augmented_html_table, decomposed_fact_list, topK=5, threshold=0.7):
    """Attribute each decomposed fact to table rows/columns and prune the table.

    For every fact, ``row_attribution`` and ``col_attribution`` pick the most
    similar row and column ids. The union of all picks is then kept by
    ``retain_rows_and_columns``.

    Args:
        augmented_html_table: HTML table whose ``<tr>``/``<th>`` elements carry
            ``description`` and ``id`` attributes.
        decomposed_fact_list: Iterable of fact strings to attribute.
        topK: Forwarded as ``topk`` to the row/column attribution helpers.
        threshold: Forwarded to the row/column attribution helpers.

    Returns:
        A tuple ``(attributed_html_table, row_attribution_ids,
        col_attribution_ids)``. The two id lists are concatenated per fact, in
        input order, and may contain duplicates.
    """
    row_attribution_ids = []
    col_attribution_ids = []
    for fact in decomposed_fact_list:
        row_attribution_ids.extend(
            row_attribution(fact, augmented_html_table, topK, threshold=threshold)
        )
        col_attribution_ids.extend(
            col_attribution(fact, augmented_html_table, topK, threshold=threshold)
        )
    attributed_html_table = retain_rows_and_columns(
        augmented_html_table, row_attribution_ids, col_attribution_ids
    )
    return attributed_html_table, row_attribution_ids, col_attribution_ids