# Binder / utils/utils.py
# Timothyxxx
# Init
# f6f97d8
"""
General utilities.
"""
import json
import os
from typing import List, Union, Dict
from functools import cmp_to_key
import math
from collections.abc import Iterable
from datasets import load_dataset
# Directory one level above the folder holding this module (presumably the project root).
# The literal "../" suffix is kept as-is: the path is neither normalized nor made absolute.
ROOT_DIR = os.path.join(
    os.path.dirname(__file__),  # folder containing this file
    "../",                      # step up one level
)
def _load_table(table_path) -> dict:
"""
attention: the table_path must be the .tsv path.
Load the WikiTableQuestion from csv file. Result in a dict format like:
{"header": [header1, header2,...], "rows": [[row11, row12, ...], [row21,...]... [...rownm]]}
"""
def __extract_content(_line: str):
_vals = [_.replace("\n", " ").strip() for _ in _line.strip("\n").split("\t")]
return _vals
with open(table_path, "r") as f:
lines = f.readlines()
rows = []
for i, line in enumerate(lines):
line = line.strip('\n')
if i == 0:
header = line.split("\t")
else:
rows.append(__extract_content(line))
table_item = {"header": header, "rows": rows}
# Defense assertion
for i in range(len(rows) - 1):
if not len(rows[i]) == len(rows[i - 1]):
raise ValueError('some rows have diff cols.')
return table_item
def majority_vote(
        nsqls: List,
        pred_answer_list: List,
        allow_none_and_empty_answer: bool = False,
        allow_error_answer: bool = False,
        answer_placeholder: Union[str, int] = '<error|empty>',
        vote_method: str = 'prob',
        answer_biased: Union[str, int] = None,
        answer_biased_weight: float = None,
):
    """
    Determine the final nsql execution answer by majority vote.

    Args:
        nsqls: ``(nsql, logprob)`` pairs, aligned position-by-position with
            ``pred_answer_list``.
        pred_answer_list: execution result of each nsql: a list of answer
            values, or the string ``'<error>'``.
        allow_none_and_empty_answer: if True, ``[None]`` / ``[]`` results vote
            for ``[answer_placeholder]`` instead of being discarded.
        allow_error_answer: if True, ``'<error>'`` results vote for
            ``[answer_placeholder]`` instead of being discarded.
        answer_placeholder: substitute answer for the cases above, and the
            answer returned when no candidate survives.
        vote_method:
            - ``'simple'``: most votes wins; equal votes are broken by the
              logprob of the first nsql recorded for each answer.
            - ``'prob'``: largest sum of ``exp(logprob)`` over an answer's nsqls.
            - ``'answer_biased'``: like ``'simple'``, but the vote count of the
              answer ``(answer_biased,)`` is multiplied by ``answer_biased_weight``.
            - ``'lf_biased'``: like ``'simple'``, but an nsql containing
              ``'map@'`` or ``'ans@'`` counts 10 votes, any other nsql 1.
        answer_biased: favoured answer value for ``'answer_biased'``.
        answer_biased_weight: positive multiplier for ``'answer_biased'``.

    Returns:
        ``(pred_answer, pred_answer_nsqls)``: the winning answer as a list and
        the ``[nsql, logprob]`` pairs that produced it. Remaining ties go to the
        answer that appeared first in the input.
        If every result was discarded, returns
        ``(answer_placeholder, [(first_nsql, first_logprob)])``, or
        ``(answer_placeholder, [])`` when ``nsqls`` is empty.

    Raises:
        ValueError: unsupported ``vote_method``.
        AssertionError: ``'answer_biased'`` without a positive ``answer_biased_weight``.
    """
    def _simple_key(item):
        # Votes first. On equal votes, compare the logprob of the first nsql recorded
        # for the answer (the max only if callers pass nsqls sorted by logprob, descending).
        answer_info = item[1]
        return answer_info['count'], answer_info['nsqls'][0][1]

    def _prob_key(item):
        # Total probability mass of every nsql that produced this answer.
        return sum(math.exp(logprob) for _, logprob in item[1]['nsqls'])

    # Vote answers (answers are grouped as tuples so they can be dict keys).
    candi_answer_dict = dict()
    for (nsql, logprob), pred_answer in zip(nsqls, pred_answer_list):
        if allow_none_and_empty_answer and (pred_answer == [None] or pred_answer == []):
            pred_answer = [answer_placeholder]
        if allow_error_answer and pred_answer == '<error>':
            pred_answer = [answer_placeholder]
        # Invalid execution results do not vote.
        if pred_answer == '<error>' or pred_answer == [None] or pred_answer == []:
            continue
        answer_info = candi_answer_dict.setdefault(tuple(pred_answer), {'count': 0, 'nsqls': []})
        answer_info['count'] += 1
        answer_info['nsqls'].append([nsql, logprob])

    # All candidates execution errors (or no candidates at all).
    if not candi_answer_dict:
        if not nsqls:
            return answer_placeholder, []
        return answer_placeholder, [(nsqls[0][0], nsqls[0][-1])]

    # Choose the ranking key; the biased methods first re-weight the vote counts.
    if vote_method == 'simple':
        rank_key = _simple_key
    elif vote_method == 'prob':
        rank_key = _prob_key
    elif vote_method == 'answer_biased':
        # Specifically for Tabfact entailed answer, i.e., `1`.
        # If there exists nsql that produces `1`, we consider it more significant because `0` is very common.
        assert answer_biased_weight is not None and answer_biased_weight > 0
        for answer, answer_info in candi_answer_dict.items():
            if answer == (answer_biased,):
                answer_info['count'] *= answer_biased_weight
        rank_key = _simple_key
    elif vote_method == 'lf_biased':
        # Assign weights to different types of logic forms (lf) to control interpretability and coverage
        for answer_info in candi_answer_dict.values():
            answer_info['count'] = sum(
                10 if ('map@' in nsql or 'ans@' in nsql) else 1
                for nsql, _ in answer_info['nsqls']
            )
        rank_key = _simple_key
    else:
        raise ValueError(f"Vote method {vote_method} is not supported.")

    # max() returns the first maximal item, so ties resolve to the earliest-seen answer
    # (a consistent total order, unlike the old comparator that never returned 0).
    best_answer, best_info = max(candi_answer_dict.items(), key=rank_key)
    return list(best_answer), best_info['nsqls']
def load_data_split(dataset_to_load, split, data_dir=os.path.join(ROOT_DIR, 'datasets/')):
    """
    Load one split of a dataset through its loading script and unify key names.

    The script ``<data_dir>/<dataset_to_load>.py`` is loaded with ``load_dataset``
    (cache under ``<data_dir>/data``) and indexed by ``split``.
    - WikiTQ-family datasets are returned unchanged (the loaded split object).
    - tab_fact / hybridqa / mmqa items are mutated in place so that each has
      ``table['page_title']`` (tab_fact also gains ``question`` and
      ``answer_text``); these are returned as a plain list of items.

    Raises:
        ValueError: if ``dataset_to_load`` is not one of the supported names
            (checked after the split has been loaded).
    """
    script_path = os.path.join(data_dir, "{}.py".format(dataset_to_load))
    loaded_split = load_dataset(path=script_path, cache_dir=os.path.join(data_dir, "data"))[split]

    # Datasets whose keys already follow the unified naming.
    passthrough_names = (
        'wikitq', 'has_squall', 'missing_squall',
        'wikitq_sql_solvable', 'wikitq_sql_unsolvable',
        'wikitq_sql_unsolvable_but_in_squall',
        'wikitq_scalability_ori',
        'wikitq_scalability_100rows',
        'wikitq_scalability_200rows',
        'wikitq_scalability_500rows',
        'wikitq_robustness',
    )
    if dataset_to_load in passthrough_names:
        return loaded_split

    def _unify_tab_fact(item):
        item['question'] = item['statement']
        item['answer_text'] = item['label']
        item['table']['page_title'] = item['table']['caption']

    def _unify_hybridqa(item):
        # The page title is the first ' | '-separated field of the context.
        item['table']['page_title'] = item['context'].split(' | ')[0]

    def _unify_mmqa(item):
        item['table']['page_title'] = item['table']['title']

    unify = {
        'tab_fact': _unify_tab_fact,
        'hybridqa': _unify_hybridqa,
        'mmqa': _unify_mmqa,
    }.get(dataset_to_load)
    if unify is None:
        raise ValueError(f'{dataset_to_load} dataset is not supported now.')

    unified_items = []
    for item in loaded_split:
        unify(item)
        unified_items.append(item)
    return unified_items
def pprint_dict(dic):
    """Write ``dic`` to stdout as JSON indented by two spaces."""
    rendered = json.dumps(dic, indent=2)
    print(rendered)
def flatten(nested_list):
    """
    Lazily yield the leaves of an arbitrarily nested iterable, depth-first.

    ``str`` and ``bytes`` values are treated as leaves rather than being
    iterated character by character; any other iterable is descended into.
    """
    for element in nested_list:
        is_container = isinstance(element, Iterable) and not isinstance(element, (str, bytes))
        if not is_container:
            yield element
            continue
        yield from flatten(element)