# NOTE: "Spaces: Running / Running" -- status text left over from the web file viewer; not part of the program.
from commonsenseConstraint import evaluation as commonsense_eval | |
from hardConstraint import evaluation as hard_eval | |
import json | |
from tqdm import tqdm | |
from datasets import load_dataset | |
def load_line_json_data(filename):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        filename: Path to a UTF-8 text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.
        An empty or whitespace-only file yields an empty list.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(filename, 'r', encoding='utf-8') as f:
        # Stream the file instead of reading it whole, and skip blank lines:
        # feeding an empty string to json.loads (empty file, stray blank
        # line) would otherwise raise.
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data
def count_true_false(data):
    """Return ``(true_count, false_count)`` for a sequence.

    Both tallies come from the sequence's own ``count`` method, so elements
    that compare equal to ``True`` / ``False`` (such as ``1`` / ``0``) are
    included, while ``None`` and other values are ignored.
    """
    tallies = tuple(data.count(flag) for flag in (True, False))
    return tallies
def statistics(commonsense_statistic):
    """Aggregate pass/fail tallies per constraint for every level/day bucket.

    Args:
        commonsense_statistic: Mapping ``level -> day -> list of info boxes``.
            Each info box is either falsy (and skipped) or a mapping from a
            constraint name to a sequence handed to ``count_true_false``.

    Returns:
        A mapping ``level -> day -> constraint name -> {"true": n, "false": m}``
        with the tallies summed over all info boxes in the bucket. Buckets
        without any usable info box map to an empty dict.
    """
    summary = {}
    for level, per_day in commonsense_statistic.items():
        summary[level] = {}
        for day, info_boxes in per_day.items():
            bucket = {}
            summary[level][day] = bucket
            for info_box in info_boxes:
                # Missing evaluations are stored as None (or empty) -- skip.
                if not info_box:
                    continue
                for name, outcome in info_box.items():
                    passed, failed = count_true_false(outcome)
                    tally = bucket.setdefault(name, {"true": 0, "false": 0})
                    tally["true"] += passed
                    tally["false"] += failed
    return summary
def eval_score(validation_or_test: str, file_path: str, TOKEN):
    """Score a file of generated plans against the TravelBenchEval queries.

    Args:
        validation_or_test: Dataset split to score against, either
            ``'validation'`` or ``'test'``.
        file_path: Path to a JSON Lines file of plans. The plan on line *i*
            is scored against query *i* of the chosen split; each plan must
            expose a ``'plan'`` entry.
        TOKEN: Access token forwarded to ``load_dataset``.

    Returns:
        A dict with the keys ``'Delivery Rate'``,
        ``'Commonsense Constraint Micro Pass Rate'``,
        ``'Commonsense Constraint Macro Pass Rate'``,
        ``'Hard Constraint Micro Pass Rate'``,
        ``'Hard Constraint Macro Pass Rate'`` and ``'Final Pass Rate'``.

    Raises:
        ValueError: If ``validation_or_test`` is not a known split.
        IndexError: If the plan file holds fewer plans than there are queries.
    """
    # Fixed per-split denominators used to normalise the returned rates.
    # NOTE(review): presumably 180 / 1000 are the split sizes and 1440 / 8000
    # are 8 commonsense checks per query; 420 / 2290 look like the number of
    # applicable hard-constraint checks -- confirm if the dataset changes.
    denominators = {
        'validation': {'plans': 180, 'commonsense': 1440, 'hard': 420},
        'test': {'plans': 1000, 'commonsense': 8000, 'hard': 2290},
    }
    if validation_or_test not in denominators:
        # Previously an unknown split surfaced as an UnboundLocalError.
        raise ValueError(
            f"validation_or_test must be 'validation' or 'test', got {validation_or_test!r}"
        )

    levels = ['easy', 'medium', 'hard']
    day_options = [3, 5, 7]

    query_data_list = list(
        load_dataset('osunlp/TravelBenchEval', validation_or_test, token=TOKEN)[validation_or_test]
    )
    hardConstraint_statistic = {level: {day: [] for day in day_options} for level in levels}
    commonsenseConstraint_statistic = {level: {day: [] for day in day_options} for level in levels}
    tested_plans = load_line_json_data(file_path)
    if len(tested_plans) < len(query_data_list):
        # Fail up front with a clear message instead of part-way through.
        raise IndexError(
            f"{file_path} contains {len(tested_plans)} plans but the "
            f"{validation_or_test} split has {len(query_data_list)} queries"
        )

    delivery_cnt = 0
    plan_constraint_store = []
    for idx, query_data in enumerate(tqdm(query_data_list)):
        tested_plan = tested_plans[idx]
        # SECURITY NOTE: eval() executes arbitrary code. It is kept because
        # the records may be Python-literal strings, but it must only ever
        # see trusted data; consider ast.literal_eval if that is sufficient.
        if isinstance(query_data, str):
            query_data = eval(query_data)
        if isinstance(tested_plan, str):
            tested_plan = eval(tested_plan)
        if isinstance(query_data['local_constraint'], str):
            query_data['local_constraint'] = eval(query_data['local_constraint'])
        # Store the parsed record back so the later passes over
        # query_data_list see a dict rather than the raw string.
        query_data_list[idx] = query_data

        if tested_plan['plan']:
            delivery_cnt += 1
            commonsense_info_box = commonsense_eval(query_data, tested_plan['plan'])
        else:
            commonsense_info_box = None

        # Hard constraints are only evaluated for plans that are complete
        # and only use information available in the sandbox.
        if (
            commonsense_info_box
            and commonsense_info_box['is_not_absent'][0]
            and commonsense_info_box['is_valid_information_in_sandbox'][0]
        ):
            hard_info_box = hard_eval(query_data, tested_plan['plan'])
        else:
            hard_info_box = None

        plan_constraint_store.append(
            {'commonsense_constraint': commonsense_info_box, 'hard_constraint': hard_info_box}
        )
        commonsenseConstraint_statistic[query_data['level']][query_data['days']].append(commonsense_info_box)
        hardConstraint_statistic[query_data['level']][query_data['days']].append(hard_info_box)

    commonsenseConstraint_statistic_processed = statistics(commonsenseConstraint_statistic)
    hardConstraint_statistic_processed = statistics(hardConstraint_statistic)

    # Per level/day bookkeeping of how many queries exist and how many of
    # them specify each local constraint. These feed the diagnostic records
    # below; they do not influence the returned rates.
    constraint_mapping = {
        'house rule': 'valid_room_rule',
        'cuisine': 'valid_cuisine',
        'room type': 'valid_room_type',
        'transportation': 'valid_transportation',
    }
    constraint_record = {
        level: {day: {name: 0 for name in constraint_mapping} for day in day_options}
        for level in ['medium', 'hard']
    }
    mapping_constraint_record = {
        level: {day: {mapped: 0 for mapped in constraint_mapping.values()} for day in day_options}
        for level in ['medium', 'hard']
    }
    count_record = {level: {day: 0 for day in day_options} for level in levels}
    for unit in query_data_list:
        level, days = unit['level'], unit['days']
        count_record[level][days] += 1
        for name, mapped in constraint_mapping.items():
            if unit['local_constraint'][name] is not None:
                constraint_record[level][days][name] += 1
                mapping_constraint_record[level][days][mapped] += 1

    # data_record ("passed/denominator" strings) and the 'total' tallies are
    # diagnostic only; just the 'pass' tallies reach the returned result.
    data_record = {level: {day: [] for day in day_options} for level in levels}
    constraint_dis_record = {"commonsense": {"pass": 0, "total": 0}, "hard": {"pass": 0, "total": 0}}
    key_dict = {
        'commonsense': [
            'is_valid_information_in_current_city',
            'is_valid_information_in_sandbox',
            'is_reasonalbe_visiting_city',
            'is_valid_restaurants',
            'is_valid_transportation',
            'is_valid_attractions',
            'is_valid_accommodation',
            'is_not_absent',
        ],
        'hard': ['valid_cost', 'valid_room_rule', 'valid_cuisine', 'valid_room_type', 'valid_transportation'],
    }
    # Hard-constraint checks whose denominator is the number of queries that
    # actually specify the constraint, rather than all queries in the bucket.
    local_checks_by_level = {
        'hard': ['valid_room_rule', 'valid_cuisine', 'valid_room_type', 'valid_transportation'],
        'medium': ['valid_room_rule', 'valid_cuisine', 'valid_room_type'],
    }
    processed_statistics = {
        'commonsense': commonsenseConstraint_statistic_processed,
        'hard': hardConstraint_statistic_processed,
    }
    for constraint, constraint_statistic in processed_statistics.items():
        tally = constraint_dis_record[constraint]
        for level, per_day in constraint_statistic.items():
            for day, per_check in per_day.items():
                for check in key_dict[constraint]:
                    data_record[level][day].append('0/0')
                    if check not in per_check:
                        continue
                    passed = per_check[check]['true']
                    tally['pass'] += passed
                    if constraint == 'hard' and check in local_checks_by_level.get(level, ()):
                        denominator = mapping_constraint_record[level][day][check]
                        tally['total'] += denominator
                    else:
                        denominator = count_record[level][day]
                        if constraint == 'commonsense' or check in [
                            'valid_cost', 'valid_visitng_city_number', 'valid_days'
                        ]:
                            tally['total'] += denominator
                    data_record[level][day][-1] = f"{passed}/{denominator}"

    final_all_cnt = 0
    final_commonsense_cnt = 0
    final_hardConstraint_cnt = 0
    final_all_cnt_map = {level: 0 for level in levels}
    for idx, stored in enumerate(plan_constraint_store):
        commonsense_box = stored['commonsense_constraint']
        hard_box = stored['hard_constraint']
        # Plans without a commonsense result, or without a hard-constraint
        # result, contribute to none of the macro counts.
        if not commonsense_box or hard_box is None:
            continue
        # A check fails only on an explicit falsy verdict; None means
        # "not applicable" and does not count against the plan.
        final_commonsense_pass = not any(
            outcome[0] is not None and not outcome[0] for outcome in commonsense_box.values()
        )
        # '== False' is kept deliberately (rather than 'is False' or 'not')
        # so the set of values treated as a failure is unchanged.
        final_hardConstraint_pass = not any(
            outcome[0] is not None and outcome[0] == False  # noqa: E712
            for outcome in hard_box.values()
        )
        if final_commonsense_pass:
            final_commonsense_cnt += 1
        if final_hardConstraint_pass:
            final_hardConstraint_cnt += 1
        if final_commonsense_pass and final_hardConstraint_pass:
            final_all_cnt += 1
            final_all_cnt_map[query_data_list[idx]['level']] += 1

    sizes = denominators[validation_or_test]
    result = {
        'Delivery Rate': delivery_cnt / sizes['plans'],
        'Commonsense Constraint Micro Pass Rate': constraint_dis_record['commonsense']['pass'] / sizes['commonsense'],
        'Commonsense Constraint Macro Pass Rate': final_commonsense_cnt / sizes['plans'],
        'Hard Constraint Micro Pass Rate': constraint_dis_record['hard']['pass'] / sizes['hard'],
        'Hard Constraint Macro Pass Rate': final_hardConstraint_cnt / sizes['plans'],
        'Final Pass Rate': final_all_cnt / sizes['plans'],
    }
    return result