Spaces:
Runtime error
Runtime error
File size: 7,703 Bytes
f6f97d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
"""
General utilities.
"""
import json
import os
from typing import List, Union, Dict
from functools import cmp_to_key
import math
from collections.abc import Iterable
from datasets import load_dataset
ROOT_DIR = os.path.join(os.path.dirname(__file__), "../")
def _load_table(table_path) -> dict:
"""
attention: the table_path must be the .tsv path.
Load the WikiTableQuestion from csv file. Result in a dict format like:
{"header": [header1, header2,...], "rows": [[row11, row12, ...], [row21,...]... [...rownm]]}
"""
def __extract_content(_line: str):
_vals = [_.replace("\n", " ").strip() for _ in _line.strip("\n").split("\t")]
return _vals
with open(table_path, "r") as f:
lines = f.readlines()
rows = []
for i, line in enumerate(lines):
line = line.strip('\n')
if i == 0:
header = line.split("\t")
else:
rows.append(__extract_content(line))
table_item = {"header": header, "rows": rows}
# Defense assertion
for i in range(len(rows) - 1):
if not len(rows[i]) == len(rows[i - 1]):
raise ValueError('some rows have diff cols.')
return table_item
def majority_vote(
        nsqls: List,
        pred_answer_list: List,
        allow_none_and_empty_answer: bool = False,
        allow_error_answer: bool = False,
        answer_placeholder: Union[str, int] = '<error|empty>',
        vote_method: str = 'prob',
        answer_biased: Union[str, int] = None,
        answer_biased_weight: float = None,
):
    """
    Determine the final nsql execution answer by majority vote.

    Args:
        nsqls: list of (nsql, logprob) pairs, one per candidate program.
        pred_answer_list: execution result per nsql, aligned with `nsqls`.
        allow_none_and_empty_answer: treat [None]/[] results as a votable
            placeholder answer instead of discarding them.
        allow_error_answer: treat '<error>' results as a votable placeholder.
        answer_placeholder: placeholder value used for none/empty/error answers.
        vote_method: one of 'simple', 'prob', 'answer_biased', 'lf_biased'.
        answer_biased: answer whose count is up-weighted ('answer_biased' only).
        answer_biased_weight: multiplier for the biased answer's count.

    Returns:
        (pred_answer, nsqls_that_produced_it). If every candidate errored,
        returns (answer_placeholder, [(first nsql, its logprob)]).

    Raises:
        ValueError: if `vote_method` is not supported.
    """
    def _compare_answer_vote_simple(a, b):
        """First compare occurrence counts; break ties by the top nsql logprob."""
        if a[1]['count'] != b[1]['count']:
            return 1 if a[1]['count'] > b[1]['count'] else -1
        a_lp, b_lp = a[1]['nsqls'][0][1], b[1]['nsqls'][0][1]
        return (a_lp > b_lp) - (a_lp < b_lp)

    def _compare_answer_vote_with_prob(a, b):
        """Compare the summed execution probabilities (exp of logprobs)."""
        a_prob = sum(math.exp(nsql[1]) for nsql in a[1]['nsqls'])
        b_prob = sum(math.exp(nsql[1]) for nsql in b[1]['nsqls'])
        # Return 0 on ties: the original returned -1 for equal sums, which is
        # an inconsistent comparator (a<b and b<a simultaneously) under
        # cmp_to_key; a proper three-way comparison keeps sorting well-defined.
        return (a_prob > b_prob) - (a_prob < b_prob)

    # Vote answers: bucket identical answers together with their programs.
    candi_answer_dict = dict()
    for (nsql, logprob), pred_answer in zip(nsqls, pred_answer_list):
        if allow_none_and_empty_answer and (pred_answer == [None] or pred_answer == []):
            pred_answer = [answer_placeholder]
        if allow_error_answer and pred_answer == '<error>':
            pred_answer = [answer_placeholder]
        # Discard invalid execution results (unless converted above).
        if pred_answer == '<error>' or pred_answer == [None] or pred_answer == []:
            continue
        answer_info = candi_answer_dict.setdefault(
            tuple(pred_answer), {'count': 0, 'nsqls': []})
        answer_info['count'] += 1
        answer_info['nsqls'].append([nsql, logprob])

    # All candidates resulted in execution errors.
    if len(candi_answer_dict) == 0:
        return answer_placeholder, [(nsqls[0][0], nsqls[0][-1])]

    # Select the comparator (and apply count re-weighting where needed).
    if vote_method == 'simple':
        comparator = _compare_answer_vote_simple
    elif vote_method == 'prob':
        comparator = _compare_answer_vote_with_prob
    elif vote_method == 'answer_biased':
        # Specifically for Tabfact entailed answer, i.e., `1`.
        # If there exists nsql that produces `1`, we consider it more
        # significant because `0` is very common.
        assert answer_biased_weight is not None and answer_biased_weight > 0
        for answer, answer_dict in candi_answer_dict.items():
            if answer == (answer_biased,):
                answer_dict['count'] *= answer_biased_weight
        comparator = _compare_answer_vote_simple
    elif vote_method == 'lf_biased':
        # Assign weights to different types of logic forms (lf) to control
        # interpretability and coverage.
        for answer_dict in candi_answer_dict.values():
            count = 0
            for nsql, _ in answer_dict['nsqls']:
                # map@/ans@ logic forms get a 10x weight over plain SQL.
                if 'map@' in nsql or 'ans@' in nsql:
                    count += 10
                else:
                    count += 1
            answer_dict['count'] = count
        comparator = _compare_answer_vote_simple
    else:
        raise ValueError(f"Vote method {vote_method} is not supported.")

    sorted_candi_answer_list = sorted(candi_answer_dict.items(),
                                      key=cmp_to_key(comparator), reverse=True)
    pred_answer_info = sorted_candi_answer_list[0]
    return list(pred_answer_info[0]), pred_answer_info[1]['nsqls']
def load_data_split(dataset_to_load, split, data_dir=os.path.join(ROOT_DIR, 'datasets/')):
    """
    Load one split of a dataset via its local HuggingFace `datasets` loading
    script (`<data_dir>/<dataset_to_load>.py`) and unify key names across
    datasets ('question', 'answer_text', table 'page_title').

    Args:
        dataset_to_load: dataset name; must match a loading script in data_dir.
        split: split name passed through to the loaded DatasetDict.
        data_dir: directory holding the loading scripts and the cache.

    Raises:
        ValueError: if the dataset name is not supported.
    """
    dataset_split_loaded = load_dataset(
        path=os.path.join(data_dir, "{}.py".format(dataset_to_load)),
        cache_dir=os.path.join(data_dir, "data"))[split]

    # Datasets already in the unified key format need no renaming.
    # (The original list contained 'wikitq' twice; deduplicated here.)
    wikitq_family = {
        'wikitq', 'has_squall', 'missing_squall',
        'wikitq_sql_solvable', 'wikitq_sql_unsolvable',
        'wikitq_sql_unsolvable_but_in_squall',
        'wikitq_scalability_ori',
        'wikitq_scalability_100rows',
        'wikitq_scalability_200rows',
        'wikitq_scalability_500rows',
        'wikitq_robustness',
    }

    # Unify names of keys.
    if dataset_to_load in wikitq_family:
        pass
    elif dataset_to_load == 'tab_fact':
        new_dataset_split_loaded = []
        for data_item in dataset_split_loaded:
            data_item['question'] = data_item['statement']
            data_item['answer_text'] = data_item['label']
            data_item['table']['page_title'] = data_item['table']['caption']
            new_dataset_split_loaded.append(data_item)
        dataset_split_loaded = new_dataset_split_loaded
    elif dataset_to_load == 'hybridqa':
        new_dataset_split_loaded = []
        for data_item in dataset_split_loaded:
            # The page title is the first ' | '-separated segment of the context.
            data_item['table']['page_title'] = data_item['context'].split(' | ')[0]
            new_dataset_split_loaded.append(data_item)
        dataset_split_loaded = new_dataset_split_loaded
    elif dataset_to_load == 'mmqa':
        new_dataset_split_loaded = []
        for data_item in dataset_split_loaded:
            data_item['table']['page_title'] = data_item['table']['title']
            new_dataset_split_loaded.append(data_item)
        dataset_split_loaded = new_dataset_split_loaded
    else:
        raise ValueError(f'{dataset_to_load} dataset is not supported now.')
    return dataset_split_loaded
def pprint_dict(dic):
    """Pretty-print a dict to stdout as 2-space-indented JSON."""
    formatted = json.dumps(dic, indent=2)
    print(formatted)
def flatten(nested_list):
    """Lazily yield the leaf elements of an arbitrarily nested iterable.

    Strings and bytes are treated as leaves, not as iterables to descend into.
    """
    for element in nested_list:
        if isinstance(element, (str, bytes)) or not isinstance(element, Iterable):
            yield element
        else:
            yield from flatten(element)
|