# Spaces: Runtime error
import argparse
import json
import os
import re
import sys

import numpy as np
from rouge import Rouge
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Benchmark categories. Each list names the datasets (matched against the
# "dataset" field of result.jsonl) whose per-dataset scores are averaged into
# a single category score by the __main__ section below.
spot_the_diff = [
    "Spot-the-Diff",
    "Birds-to-Words",
    "CLEVR-Change",
]
image_edit_instruct = [
    "IEdit",
    "HQ-Edit",
    "MagicBrush",
]
visual_story_telling = [
    "AESOP",
    "FlintstonesSV",
    "PororoSV",
    "VIST",
]
visual_cloze = [
    "COMICS_Dialogue",
    "RecipeQA_VisualCloze",
]
text_rich_vqa = [
    "WebQA",
    "TQA",
    "OCR-VQA",
    "DocVQA",
]
multi_image_vqa = [
    "MIT-States_StateCoherence",
    "MIT-States_PropertyCoherence",
    "VISION",
    "RecipeQA_ImageCoherence",
]
puzzle = [
    "RAVEN",
]
# NOTE: the name is spelled "nlrv2" (not "nlvr2"); it is also the category key
# written to eval_cat.json, so it is left as-is.
nlrv2 = [
    "NLVR2_Mantis",
]
qbench = [
    "QBench",
]
class Eval:
    """Scoring helpers for open-ended (ROUGE-L) and multiple-choice predictions.

    Every prediction handled here is a dict carrying at least the keys
    ``sample_id``, ``gt_response`` and ``pred_response``.
    """

    # Option letters recognised when a response looks like "prefix: x".
    _CHOICE_LETTERS = ("a", "b", "c", "d", "e", "f", "g", "h")

    def __init__(self):
        # NOTE: ``(?!<=\d)`` is a negative *lookahead* for the text "<=digit",
        # not the lookbehind ``(?<!\d)`` it resembles. It always succeeds in
        # front of a ".", so this pattern effectively matches every period
        # that is not followed by a digit (decimal points survive). Kept
        # verbatim so normalisation, and therefore scores, do not change.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # A comma sandwiched between two digits, e.g. "1,000".
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def processPunctuation(self, inText):
        """Clean punctuation out of *inText* and return the result.

        Each mark in ``self.punct`` is deleted when it touches a space in the
        input, or when the input contains a digit-comma-digit run. Otherwise
        it is replaced by a space. Afterwards every period not followed by a
        digit is removed.
        """
        outText = inText
        # Depends only on the unmodified input, so evaluate it once.
        has_digit_comma = self.commaStrip.search(inText) is not None
        for p in self.punct:
            if p + " " in inText or " " + p in inText or has_digit_comma:
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # Pattern.sub's third positional parameter is ``count``, not ``flags``.
        # Passing re.UNICODE there (as before) silently capped removal at 32
        # periods. str patterns are Unicode-aware by default anyway.
        return self.periodStrip.sub("", outText)

    def process(self, answer):
        """Normalise an answer string for comparison.

        Newlines and tabs become spaces, punctuation is cleaned via
        :meth:`processPunctuation`, and surrounding quotes and parentheses
        are stripped. The result is trimmed and lower-cased.
        """
        answer = answer.replace("\n", " ")
        answer = answer.replace("\t", " ")
        answer = answer.strip()
        answer = self.processPunctuation(answer)
        answer = answer.strip('\'')
        answer = answer.strip('\"')
        answer = answer.strip(')')
        answer = answer.strip('(')
        answer = answer.strip().lower()
        return answer

    def evaluate_rouge(self, preds):
        """Score open-ended predictions with ROUGE-L F1.

        Samples whose normalised ground truth is empty are skipped. An empty
        normalised prediction scores 0. Predictions are truncated to 512
        characters before scoring (presumably to bound the cost of the LCS
        computation).

        Returns:
            ``({'Rouge-L f': mean_score}, details)``, where *details* holds
            ``{'id': str, 'score': str}`` entries for each scored sample. The
            mean is ``nan`` when no sample was scored (``np.mean`` of an
            empty list).
        """
        rouge = Rouge()
        scores = []
        eval_list = []
        for res in preds:
            sample_id = res['sample_id']
            gt_ans = self.process(res["gt_response"])
            pred_ans = self.process(res["pred_response"])
            if gt_ans == '':
                continue
            if pred_ans == '':
                s = 0
            else:
                pred_ans = pred_ans[:512]
                s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f']
            scores.append(s)
            eval_list.append({'id': str(sample_id), 'score': str(round(s, 3))})
        results = {'Rouge-L f': np.mean(scores)}
        return results, eval_list

    @classmethod
    def _extract_choice(cls, pred_ans):
        """Pull a lone option letter out of a "...: x" response.

        If *pred_ans* contains ":", it is split on ":" and the last stripped
        piece that is a single letter in ``_CHOICE_LETTERS`` is returned.
        Otherwise, or when no piece qualifies, *pred_ans* is returned
        unchanged.
        """
        if ":" in pred_ans:
            for part in pred_ans.split(":"):
                part = part.strip()
                if len(part) == 1 and part in cls._CHOICE_LETTERS:
                    pred_ans = part
        return pred_ans

    def judge_multi_choice(self, sample):
        """Return 1 if the (already normalised) prediction matches the answer, else 0."""
        pred_ans = self._extract_choice(sample["pred_response"])
        return 1 if pred_ans == sample["gt_response"] else 0

    def process_sample(self, sample):
        """Normalise ``gt_response`` and ``pred_response`` of *sample* in place."""
        sample["gt_response"] = self.process(sample["gt_response"])
        sample["pred_response"] = self.process(sample["pred_response"])

    def evaluate_multichoice(self, preditions):
        """Compute multiple-choice accuracy.

        Each sample is normalised in place (see :meth:`process_sample`) and
        gets a ``'result'`` key holding its 0/1 score.

        Returns:
            ``({'Accuracy': float}, details)``, where *details* holds
            ``{'id': str, 'score': str}`` entries. An empty input yields
            accuracy 0.0 instead of dividing by zero.
        """
        if not preditions:
            return {'Accuracy': 0.0}, []
        correct = 0
        eval_list = []
        for sample in preditions:
            self.process_sample(sample)
            score = self.judge_multi_choice(sample)
            sample['result'] = score
            eval_list.append({'id': str(sample['sample_id']), 'score': str(score)})
            correct += score
        return {'Accuracy': correct / len(preditions)}, eval_list

    def evaluate_multi_choice_image(self, preditions):
        """Compute multiple-choice accuracy without rewriting the responses.

        Scores exactly like :meth:`evaluate_multichoice`, but normalised
        text is not written back into the samples; only ``'result'`` is
        added. An empty input yields accuracy 0.0.
        """
        if not preditions:
            return {'Accuracy': 0.0}, []
        correct = 0
        eval_list = []
        for sample in preditions:
            gt_ans = self.process(sample["gt_response"])
            pred_ans = self._extract_choice(self.process(sample["pred_response"]))
            score = 1 if gt_ans == pred_ans else 0
            sample['result'] = score
            eval_list.append({'id': str(sample['sample_id']), 'score': str(score)})
            correct += score
        return {'Accuracy': correct / len(preditions)}, eval_list
if __name__ == "__main__":
    # Usage: --result-dir DIR, where DIR contains result.jsonl. Per-dataset
    # scores, per-sample details and per-category averages are written back
    # into DIR as eval_dataset.json, eval_dataset_details.json, eval_cat.json.
    parser = argparse.ArgumentParser()
    parser.add_argument('--result-dir', type=str, required=True)
    args = parser.parse_args()

    result_file = os.path.join(args.result_dir, "result.jsonl")
    if not os.path.exists(result_file):
        print('No prediction file found')
        sys.exit(0)

    # One JSON object per line; blank lines (e.g. a trailing newline) are skipped
    # instead of crashing json.loads.
    with open(result_file, 'r', encoding='utf-8') as f:
        preds_all = [json.loads(line) for line in f if line.strip()]

    # Group predictions by their "dataset" field, in first-seen order.
    preds_all_dict = dict()
    for pred in preds_all:
        preds_all_dict.setdefault(pred["dataset"], []).append(pred)

    # Datasets routed to Eval.evaluate_multi_choice_image (same scoring as
    # evaluate_multichoice, but responses are not rewritten in place).
    # NOTE(review): "recipeqa-RecipeQA_VisualCloze" differs from the
    # "RecipeQA_VisualCloze" name used in visual_cloze; confirm which one the
    # result file uses. Scores are identical on either route.
    image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"]

    E = Eval()
    eval_result_list = dict()
    eval_result_list_detail = dict()
    for dataset, preds in preds_all_dict.items():
        # The first sample's question_type decides the metric for the dataset.
        question_type = preds[0]["question_type"]
        if question_type == 'open-ended':
            eval_result, eval_list = E.evaluate_rouge(preds)
        elif question_type == 'multi-choice' or dataset == 'nlrv2' or dataset in nlrv2:
            # 'nlrv2' alone never matched the real name ("NLVR2_Mantis"); the
            # category list is consulted too, and the old check is kept.
            if dataset in image_choice_dataset_list:
                eval_result, eval_list = E.evaluate_multi_choice_image(preds)
            else:
                eval_result, eval_list = E.evaluate_multichoice(preds)
        else:
            # Skip rather than exit: exiting here used to discard every result
            # computed so far and write no output files at all.
            print(f'{dataset}: Dataset not supported', file=sys.stderr)
            continue
        print(dataset, end=': ')
        print(eval_result)
        eval_result_list[dataset] = eval_result
        eval_result_list_detail[dataset] = eval_list

    os.makedirs(args.result_dir, exist_ok=True)
    with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f:
        json.dump(eval_result_list, f, indent=4)
    with open(os.path.join(args.result_dir, 'eval_dataset_details.json'), 'w') as f:
        json.dump(eval_result_list_detail, f, indent=4)

    # Category name -> member datasets. The order fixes both the printed order
    # and the key order in eval_cat.json.
    categories = [
        ("spot_the_diff", spot_the_diff),
        ("image_edit_instruct", image_edit_instruct),
        ("visual_story_telling", visual_story_telling),
        ("visual_cloze", visual_cloze),
        ("text_rich_vqa", text_rich_vqa),
        ("multi_image_vqa", multi_image_vqa),
        ("puzzle", puzzle),
        ("nlrv2", nlrv2),
        ("qbench", qbench),
    ]
    eval_cat_list = dict()
    print()
    for cat_name, cat_datasets in categories:
        # Each evaluated dataset contributes its first metric value
        # (Accuracy or Rouge-L f). The category score is the unweighted mean,
        # or 0 when none of its datasets were evaluated.
        cat_scores = [
            next(iter(result.values()))
            for dataset, result in eval_result_list.items()
            if dataset in cat_datasets
        ]
        score = sum(cat_scores) / len(cat_scores) if cat_scores else 0
        eval_cat_list[cat_name] = score
        print(cat_name, end=': ')
        print('{:.2f}'.format(100 * score))

    with open(os.path.join(args.result_dir, 'eval_cat.json'), 'w') as f:
        json.dump(eval_cat_list, f, indent=4)