# NVLM-D-72B / eval / mmmu_utils.py
# Author: boxinw@nvidia.com
# Commit b925209: "Add benchmark evaluation scripts" (24.6 kB)
# Adapted from https://github.com/MMMU-Benchmark/MMMU/blob/main/mmmu/utils/data_utils.py
"""Utils for data load, save, and process (e.g., prompt construction)"""
import os
import json
import yaml
import re
# Domain name -> subject names belonging to that domain. mmmu_main_eval walks
# this mapping (in this order) to build the per-domain "Overall-<domain>" rows.
DOMAIN_CAT2SUB_CAT = {
    'Art and Design': [
        'Art',
        'Art_Theory',
        'Design',
        'Music',
    ],
    'Business': [
        'Accounting',
        'Economics',
        'Finance',
        'Manage',
        'Marketing',
    ],
    'Science': [
        'Biology',
        'Chemistry',
        'Geography',
        'Math',
        'Physics',
    ],
    'Health and Medicine': [
        'Basic_Medical_Science',
        'Clinical_Medicine',
        'Diagnostics_and_Laboratory_Medicine',
        'Pharmacy',
        'Public_Health',
    ],
    'Humanities and Social Science': [
        'History',
        'Literature',
        'Sociology',
        'Psychology',
    ],
    'Tech and Engineering': [
        'Agriculture',
        'Architecture_and_Engineering',
        'Computer_Science',
        'Electronics',
        'Energy_and_Power',
        'Materials',
        'Mechanical_Engineering',
    ],
}
# Short subject tag -> full subject name. mmmu_main_eval iterates the values
# of this mapping to decide which categories get scored.
CAT_SHORT2LONG = dict(
    acc='Accounting',
    agri='Agriculture',
    arch='Architecture_and_Engineering',
    art='Art',
    art_theory='Art_Theory',
    bas_med='Basic_Medical_Science',
    bio='Biology',
    chem='Chemistry',
    cli_med='Clinical_Medicine',
    cs='Computer_Science',
    design='Design',
    diag_med='Diagnostics_and_Laboratory_Medicine',
    econ='Economics',
    elec='Electronics',
    ep='Energy_and_Power',
    fin='Finance',
    geo='Geography',
    his='History',
    liter='Literature',
    manage='Manage',
    mark='Marketing',
    mate='Materials',
    math='Math',
    mech='Mechanical_Engineering',
    music='Music',
    phar='Pharmacy',
    phys='Physics',
    psy='Psychology',
    pub_health='Public_Health',
    socio='Sociology',
)
# DATA SAVING
def save_json(filename, ds):
    """Write ``ds`` to ``filename`` as JSON with a 4-space indent, overwriting the file."""
    with open(filename, 'w') as json_file:
        json.dump(ds, json_file, indent=4)
def get_multi_choice_info(options):
    """
    Label the options of a multiple-choice question with consecutive letters.

    Args:
        options: The option contents, in display order.

    Returns:
        A pair ``(index2ans, all_choices)``: ``index2ans`` maps each letter
        ('A', 'B', ...) to its option, ``all_choices`` lists the letters in order.
    """
    index2ans = {}
    for offset, option in enumerate(options):
        index2ans[chr(ord('A') + offset)] = option
    # Letters are unique and inserted in order, so the keys are exactly the choice list.
    return index2ans, list(index2ans)
def load_yaml(file_path):
    """
    Parse a YAML file with ``yaml.safe_load`` and return the result.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file content.

    Raises:
        yaml.YAMLError: If the content cannot be parsed. The error is printed
            and then re-raised. Previously the function fell through to a
            ``return`` of a never-assigned variable and died with an
            unrelated ``UnboundLocalError``.
    """
    with open(file_path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
def parse_img_path(text):
    """Return every PATH written as ``<img='PATH'>`` inside ``text``, in order of appearance."""
    img_tag = re.compile("<img='(.*?)'>")
    return img_tag.findall(text)
def process_single_sample(data):
    """
    Reduce a raw dataset record to the fields used downstream.

    The returned dict carries 'id', 'question', 'options', 'answer', 'image',
    'question_type' and 'subfield'. 'image' is ``data['image_1']`` unless the
    options reference more than one ``<img='...'>`` path, in which case it is
    ``None``.
    """
    question = data['question']
    option_image_paths = [
        img_path
        for option in data['options']
        for img_path in parse_img_path(option)
    ]
    # multiple images in options, used for random selection
    images_in_options = len(option_image_paths) > 1
    return {
        'id': data['id'],
        'question': question,
        'options': data['options'],
        'answer': data['answer'],
        'image': None if images_in_options else data['image_1'],
        'question_type': data['question_type'],
        'subfield': data['subfield'],
    }
def process_single_sample_pro(data):
    """
    Variant of the sample reducer for records that carry a 'subject' field.

    The returned dict has the keys 'id', 'question', 'options', 'answer',
    'image', 'question_type' and 'subfield'. 'question_type' is always
    'multiple-choice' and 'subfield' is taken from ``data['subject']``.
    'image' is ``data['image_1']`` unless the options reference more than one
    ``<img='...'>`` path, in which case it is ``None``.
    """
    question = data['question']
    option_image_paths = []
    for option in data['options']:
        option_image_paths.extend(parse_img_path(option))
    # multiple images in options, used for random selection
    images_in_options = len(option_image_paths) > 1
    return {
        'id': data['id'],
        'question': question,
        'options': data['options'],
        'answer': data['answer'],
        'image': None if images_in_options else data['image_1'],
        'question_type': 'multiple-choice',
        'subfield': data['subject'],
    }
# DATA SAVING
def save_json(filename, ds):
    """Dump ``ds`` into ``filename`` as indented (4 spaces) JSON, replacing any existing file."""
    with open(filename, 'w') as sink:
        json.dump(ds, sink, indent=4)
def save_jsonl(filename, data):
    """
    Write a mapping of image paths to captions as a JSON Lines file.

    Each output line is a one-entry JSON object whose key is the base name of
    the image path (directory part removed) and whose value is the caption.

    Args:
        filename (str): Destination file; it is overwritten.
        data (dict): Image path -> caption.
    """
    with open(filename, 'w', encoding='utf-8') as sink:
        for img_path, caption in data.items():
            record = {os.path.basename(img_path): caption}
            # ensure_ascii=False keeps non-ASCII captions readable in the file.
            sink.write(json.dumps(record, ensure_ascii=False) + '\n')
def save_args(args, path_dir):
    """
    Dump the attributes of ``args`` to a plain-text settings file.

    The file name is ``path_dir + 'setting.txt'`` (plain string concatenation,
    so ``path_dir`` needs its own trailing separator). Each attribute becomes
    one ``name : value`` line between a start and an end marker.
    """
    settings = args.__dict__
    with open(path_dir + 'setting.txt', 'w') as report:
        report.write('------------------ start ------------------' + '\n')
        for arg_name, arg_value in settings.items():
            report.write(' : '.join((arg_name, str(arg_value))) + '\n')
        report.write('------------------- end -------------------')
# DATA PROCESSING
def construct_prompt(sample, config):
    """
    Build the prompt and evaluation bookkeeping for one sample.

    Args:
        sample: Dict with at least 'question', 'options' (a string holding a
            Python list literal), 'answer' and 'question_type'.
        config: Dict with 'task_instructions' plus the format templates
            'multi_choice_example_format' (two positional slots: question,
            options block) and 'short_ans_example_format' (one slot).

    Returns:
        A dict with 'type', 'empty_prompt', 'final_input_prompt' and
        'gt_content' (plus 'index2ans', 'correct_choice' and 'all_choices'
        for multiple-choice), then updated with every key of ``sample``, so
        sample keys win on a name clash.
    """
    question_text = sample['question'].strip()
    # NOTE(security): eval() executes whatever expression the dataset string
    # contains. ast.literal_eval would be the safer choice if 'options' is
    # always a plain list literal - TODO confirm against the data source.
    options = eval(sample['options'])
    is_multiple_choice = sample['question_type'] == 'multiple-choice'

    if is_multiple_choice:
        choice_letters = []
        index2ans = {}
        option_lines = []
        for offset, option in enumerate(options):
            letter = chr(ord('A') + offset)
            choice_letters.append(letter)
            option_lines.append(f"({letter}) {option}\n")
            index2ans[letter] = option
        template = config['multi_choice_example_format']
        empty_prompt = template.format(question_text, "".join(option_lines))
        res_dict = {
            'type': 'multichoice',
            'index2ans': index2ans,
            'correct_choice': sample['answer'],
            'all_choices': choice_letters,
            'empty_prompt': empty_prompt,
        }
    else:
        template = config['short_ans_example_format']
        empty_prompt = template.format(question_text)
        res_dict = {'type': 'open', 'empty_prompt': empty_prompt}

    # Prefix the task instructions only when some are configured.
    if config['task_instructions']:
        final_prompt = config['task_instructions'].strip() + '\n\n' + empty_prompt
    else:
        final_prompt = empty_prompt
    res_dict['final_input_prompt'] = final_prompt

    if is_multiple_choice:
        # The answer is a letter; translate it into the option content.
        res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
    else:
        res_dict['gt_content'] = sample['answer']

    res_dict.update(sample)
    return res_dict
def construct_prompt_pro(sample, config):
    """
    Build the prompt and evaluation bookkeeping for one sample, tolerating a
    wrapped option list.

    Works like the plain prompt builder, with one extra step: when the
    evaluated 'options' has exactly one element, that element is taken to be
    the real option list (a notice is printed).

    Args:
        sample: Dict with at least 'question', 'options' (a string holding a
            Python list literal), 'answer' and 'question_type'.
        config: Dict with 'task_instructions', 'multi_choice_example_format'
            and 'short_ans_example_format'.

    Returns:
        A dict with 'type', 'empty_prompt', 'final_input_prompt' and
        'gt_content' (plus 'index2ans', 'correct_choice' and 'all_choices'
        for multiple-choice), then updated with every key of ``sample``.
    """
    question_text = sample['question'].strip()
    # NOTE(security): eval() executes whatever expression the dataset string
    # contains. ast.literal_eval would be the safer choice if 'options' is
    # always a plain list literal - TODO confirm against the data source.
    options = eval(sample['options'])
    if len(options) == 1:
        print("This is wrongly formated. We correct to options[0].")
        options = options[0]

    multiple_choice = sample['question_type'] == 'multiple-choice'
    if not multiple_choice:
        empty_prompt = config['short_ans_example_format'].format(question_text)
        res_dict = {'type': 'open', 'empty_prompt': empty_prompt}
    else:
        labelled = [(chr(ord('A') + offset), option) for offset, option in enumerate(options)]
        option_block = "".join(f"({letter}) {option}\n" for letter, option in labelled)
        empty_prompt = config['multi_choice_example_format'].format(question_text, option_block)
        res_dict = {
            'type': 'multichoice',
            'index2ans': dict(labelled),
            'correct_choice': sample['answer'],
            'all_choices': [letter for letter, _ in labelled],
            'empty_prompt': empty_prompt,
        }

    instructions = config['task_instructions']
    # Prefix the task instructions only when some are configured.
    res_dict['final_input_prompt'] = (
        instructions.strip() + '\n\n' + empty_prompt if instructions else empty_prompt
    )

    if multiple_choice:
        # The answer is a letter; translate it into the option content.
        answer_offset = ord(sample['answer'].upper()) - ord('A')
        res_dict['gt_content'] = options[answer_offset]
    else:
        res_dict['gt_content'] = sample['answer']

    res_dict.update(sample)
    return res_dict
"""Response Parsing and Evaluation for various models"""
from typing import Dict
import re
import random
import numpy as np
# ----------- Process Multi-choice -------------
def parse_multi_choice_response(response, all_choices, index2ans):
    """
    Parse the predicted choice letter out of a generated response.

    Matching is attempted in three stages, stopping at the first that finds
    something:
      1. bracketed letters, ``(A)`` or ``A)``;
      2. bare letters surrounded by spaces, `` A ``;
      3. for responses longer than five whitespace-separated tokens, the
         option content itself (case-insensitive).
    When several candidates match, the one mentioned last in the response
    wins. When nothing matches, the first entry of ``all_choices`` is returned.

    Args:
        response: Raw text produced by the model.
        all_choices: Choice letters, e.g. ``['A', 'B', 'C', 'D']``.
        index2ans: Mapping from choice letter to option content.

    Returns:
        The predicted choice letter.
    """
    # Strip any mix of leading/trailing punctuation in a single pass. Stripping
    # one character at a time (as before) depended on the order of the
    # characters, so e.g. "B.:" kept its '.' and could no longer match ' B '.
    response = response.strip(",.!?;:'")
    response = " " + response + " "  # add space to avoid partial match

    index_ans = True
    ans_with_brack = False
    candidates = []

    for choice in all_choices:  # e.g., (A) (B) (C) (D) A) B) C) D)
        # '(A)' always contains 'A)', so this one test covers both spellings.
        if f'{choice})' in response:
            candidates.append(choice)
            ans_with_brack = True

    if not candidates:
        for choice in all_choices:  # e.g., A B C D
            if f' {choice} ' in response:
                candidates.append(choice)

    # Still nothing: for longer responses, try to find the option content.
    if not candidates and len(response.split()) > 5:
        lowered_response = response.lower()
        for index, ans in index2ans.items():
            if str(ans).lower() in lowered_response:
                candidates.append(index)
                index_ans = False  # it's content ans.

    if not candidates:
        # No usable signal: fall back deterministically to the first choice.
        return all_choices[0]
    if len(candidates) == 1:
        return candidates[0]

    # Several candidates: keep the one whose last mention comes latest.
    if index_ans:
        if ans_with_brack:
            # Search for 'A)' rather than '(A)': candidates found via the
            # unbracketed 'A)' form would otherwise all report -1 and the
            # first candidate would win instead of the last one mentioned.
            start_indexes = [response.rfind(f'{can})') for can in candidates]
        else:
            start_indexes = [response.rfind(f' {can} ') for can in candidates]
    else:
        lowered_response = response.lower()
        start_indexes = [
            lowered_response.rfind(str(index2ans[can]).lower()) for can in candidates
        ]
    # max() returns the first maximal position on ties, like argmax would.
    latest = max(range(len(candidates)), key=start_indexes.__getitem__)
    return candidates[latest]
# ----------- Process Open -------------
def check_is_number(string):
    """
    Tell whether ``string`` parses as a float once its commas are removed.
    """
    without_commas = string.replace(',', '')
    try:
        float(without_commas)
    except ValueError:
        return False
    return True
def normalize_str(string):
    """
    Normalize an answer string for comparison.

    Numeric text (commas allowed) becomes a one-element list holding the
    float rounded to 2 decimals. Any other text is lower-cased; a
    single-character result is returned in two space-padded variants so it
    cannot match in the middle of a word.
    """
    stripped = string.strip()
    if not check_is_number(stripped):
        # it's likely to be a string
        lowered = stripped.lower()
        if len(lowered) == 1:
            return [" " + lowered, lowered + " "]  # avoid trivial matches
        return [lowered]
    # numeric: drop the commas, convert, keep 2 decimals
    return [round(float(stripped.replace(',', '')), 2)]
def extract_numbers(string):
    """
    Extract number-like substrings from ``string`` with three regex passes.

    The result lists all comma-grouped numbers first, then all
    scientific-notation numbers, then all plain integers/decimals. Each group
    is in order of appearance, and one piece of text may show up in more than
    one group.
    """
    patterns = (
        r'-?\b\d{1,3}(?:,\d{3})+\b',  # numbers with commas
        r'-?\d+(?:\.\d+)?[eE][+-]?\d+',  # scientific notation
        r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])',  # simple numbers without commas
    )
    all_numbers = []
    for pattern in patterns:
        all_numbers.extend(re.findall(pattern, string))
    return all_numbers
def parse_open_response(response):
    """
    Parse the prediction from the generated response of an open question.

    The response is cut into sub-responses; for each one the shortest text
    following an indicator phrase ("is ", "therefore ", ...) is kept as a key
    response. The key responses plus every number found in them are then
    normalized with ``normalize_str``.

    Args:
        response: Raw text produced by the model.

    Returns:
        A list of normalized predictions (lower-cased strings and/or floats),
        without duplicates, in first-occurrence order so that the result is
        identical from run to run.
    """
    def get_key_subresponses(response):
        """Return the candidate answer fragments of ``response`` (or the whole text)."""
        response = response.strip().strip(".").lower()
        # NOTE(review): the text is already lower-cased at this point, so the
        # `(?=[A-Z])` lookahead can never match and only the newline
        # alternative actually splits. Left unchanged on purpose, because
        # changing it would shift evaluation scores - confirm before fixing.
        sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
        indicators_of_keys = ['could be ', 'so ', 'is ',
                              'thus ', 'therefore ', 'final ', 'answer ', 'result ']
        key_responses = []
        last_index = len(sub_responses) - 1
        for index, resp in enumerate(sub_responses):
            # if last one, accept it's an equation (the entire response can be
            # just one sentence with equation)
            if index == last_index:
                indicators_of_keys.append('=')
            # the shortest tail that may contain the answer
            shortest_key_response = None
            for indicator in indicators_of_keys:
                if indicator in resp:
                    tail = resp.split(indicator)[-1].strip()
                    if not shortest_key_response or len(tail) < len(shortest_key_response):
                        shortest_key_response = tail
            if shortest_key_response:
                # and it's not trivial
                if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
                    key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not find any
            return [response]
        return key_responses

    key_responses = get_key_subresponses(response)

    pred_list = key_responses.copy()  # keep the original string response
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))

    normalized = []
    for pred in pred_list:
        normalized.extend(normalize_str(pred))

    # Remove duplicates while keeping first-occurrence order. list(set(...))
    # produced a different order on every run because of string hash
    # randomization.
    return list(dict.fromkeys(normalized))
# ----------- Evaluation -------------
def eval_multi_choice(gold_i, pred_i):
    """
    Score one multiple-choice prediction.

    Returns True only on an exact match: with the single gold answer, or with
    any entry when ``gold_i`` is a list of accepted answers.
    """
    if isinstance(gold_i, list):
        return any(answer == pred_i for answer in gold_i)
    # gold_i is a string
    return bool(gold_i == pred_i)
def eval_open(gold_i, pred_i):
    """
    Score one open-question prediction.

    ``gold_i`` (a string or a list of strings) is normalized with
    ``normalize_str``; ``pred_i`` is the already-normalized list of predicted
    strings/numbers. A string prediction counts when some string gold answer
    occurs inside it; a numeric prediction counts when it equals one of the
    normalized gold answers. Returns True as soon as any prediction counts.
    """
    if isinstance(gold_i, list):
        # use float to avoid trivial matches
        norm_answers = [variant for answer in gold_i for variant in normalize_str(answer)]
    else:
        norm_answers = normalize_str(gold_i)

    text_answers = [ans for ans in norm_answers if isinstance(ans, str)]
    for pred in pred_i:  # pred is already normalized in parse response phase
        if isinstance(pred, str):
            # only see if the string answer is in the string pred
            if any(ans in pred for ans in text_answers):
                return True
        elif pred in norm_answers:  # it's a float number
            return True
    return False
# ----------- Batch Evaluation -------------
def evaluate(samples):
    """
    Batch evaluation for multiple choice and open questions.

    Args:
        samples: Sequence of dicts with 'id', 'question_type', 'answer' and
            'parsed_pred'.

    Returns:
        A pair ``(judge_dict, metric_dict)``: ``judge_dict`` maps each sample
        id to 'Correct' or 'Wrong', ``metric_dict`` is ``{'acc': accuracy}``.
        For empty input the pair is ``({}, {'acc': 0})``. Previously only the
        bare ``{'acc': 0}`` dict was returned in that case, which broke
        callers that unpack two values.
    """
    judge_dict = dict()
    if len(samples) == 0:
        return judge_dict, {'acc': 0}

    pred_correct = 0
    for sample in samples:
        gold_i = sample['answer']
        pred_i = sample['parsed_pred']
        if sample['question_type'] == 'multiple-choice':
            correct = eval_multi_choice(gold_i, pred_i)
        else:  # open question
            correct = eval_open(gold_i, pred_i)

        if correct:
            judge_dict[sample['id']] = 'Correct'
            pred_correct += 1
        else:
            judge_dict[sample['id']] = 'Wrong'

    return judge_dict, {'acc': pred_correct / len(samples)}
# ----------- Calculate Accuracy -------------
def calculate_ins_level_acc(results: Dict):
    """
    Combine per-category accuracies into one example-weighted accuracy.

    Each value of ``results`` must provide 'acc' and 'num_example'. Returns 0
    when the categories contain no examples at all.
    """
    weighted_correct = 0
    total_examples = 0
    for cat_results in results.values():
        accuracy = cat_results['acc']
        count = cat_results['num_example']
        weighted_correct += accuracy * count
        total_examples += count
    return 0 if total_examples == 0 else weighted_correct / total_examples
def _group_by_category(records):
    """Bucket a ``{data_id: value}`` mapping by the category embedded in each id.

    The category is everything between the first and the last '_'-separated
    field of the id.
    """
    grouped = {}
    for data_id, value in records.items():
        category = "_".join(data_id.split("_")[1:-1])
        grouped.setdefault(category, {})[data_id] = value
    return grouped


def mmmu_main_eval(output_dict, task_cfg):
    """
    Score parsed predictions against the answer file, per category and domain.

    Args:
        output_dict: Mapping of data id to prediction. Multiple-choice entries
            are used as given; all other entries are passed through
            ``parse_open_response`` first.
        task_cfg: Mapping whose "answer_dict" entry is the path of a JSON file
            mapping data id to a record with 'question_type' and
            'ground_truth'.

    Returns:
        A dict of ``{"num": ..., "acc": ...}`` rows: one 'Overall-<domain>' row
        per domain followed by that domain's categories, and a final 'Overall'
        row. Accuracies are rounded to 4 decimals. Categories missing from
        either the predictions or the answers are skipped with a printed
        notice.

    Raises:
        KeyError: If a predicted data id has no entry in the answer file.
    """
    # Use a context manager so the answer file is closed deterministically.
    with open(task_cfg["answer_dict"], encoding="utf-8") as answer_file:
        answer_dict = json.load(answer_file)

    # group by category
    output_dict_w_cat = _group_by_category(output_dict)
    answer_dict_w_cat = _group_by_category(answer_dict)

    evaluation_result = {}
    for category in CAT_SHORT2LONG.values():
        if category not in output_dict_w_cat or category not in answer_dict_w_cat:
            print("Skipping {} for not found".format(category))
            continue
        cat_outputs = output_dict_w_cat[category]
        cat_answers = answer_dict_w_cat[category]

        examples_to_eval = []
        for data_id, parsed_pred in cat_outputs.items():
            question_type = cat_answers[data_id]['question_type']
            if question_type != 'multiple-choice':
                # mainly for type consistency (make it number, etc.)
                parsed_pred = parse_open_response(parsed_pred)
            examples_to_eval.append({
                "id": data_id,
                "question_type": question_type,
                "answer": cat_answers[data_id]['ground_truth'],
                "parsed_pred": parsed_pred,
            })

        _, metric_dict = evaluate(examples_to_eval)
        metric_dict.update({"num_example": len(examples_to_eval)})
        evaluation_result[category] = metric_dict

    printable_results = {}
    # add domain Subject
    for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
        # use the order in DOMAIN_CAT2SUB_CAT
        in_domain_cat_results = {
            cat_name: evaluation_result[cat_name]
            for cat_name in in_domain_cats
            if cat_name in evaluation_result
        }
        in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
        in_domain_data_num = sum(
            cat_results['num_example'] for cat_results in in_domain_cat_results.values()
        )
        printable_results['Overall-' + domain] = {
            "num": int(in_domain_data_num),
            "acc": round(in_domain_ins_acc, 4),
        }
        # add sub category
        for cat_name, cat_results in in_domain_cat_results.items():
            printable_results[cat_name] = {
                "num": int(cat_results['num_example']),
                "acc": round(cat_results['acc'], 4),
            }

    all_ins_acc = calculate_ins_level_acc(evaluation_result)
    printable_results['Overall'] = {
        "num": sum(cat_results['num_example'] for cat_results in evaluation_result.values()),
        "acc": round(all_ins_acc, 4),
    }
    return printable_results
if __name__ == '__main__':
    # Ad-hoc driver: parse a merged prediction file, dump the parsed
    # predictions, and print the MMMU score table. All paths are hard-coded.
    # tasks = yaml.safe_load(open("eval_config/eval_mmmu_yi_oci.yaml"))['datasets']
    with open("eval_config/eval_mmmu_yi.yaml") as config_file:
        tasks = yaml.safe_load(config_file)['datasets']
    print(tasks)

    # with open("/lustre/fs4/portfolios/adlr/users/boxinw/llava-megatron-gen/checkpoints/test/eval_mmmu_iter500_merged.4node.json") as f:
    with open("/lustre/fsw/portfolios/llmservice/users/boxinw/eval_mmmu_iter6000_merged.0.53.json") as f:
        merged_results = json.load(f)

    eval_samples = []
    eval_output_dict = {}
    for res in merged_results:
        pred_ans = res["answer"].upper()
        is_multiple_choice = res['question_type'] == 'multiple-choice'

        if is_multiple_choice:
            parsed_pred = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
            if pred_ans != parsed_pred:
                print(f"MC: Original: {pred_ans}, Parsed: {parsed_pred}")
        else:
            parsed_pred = parse_open_response(pred_ans)
            if pred_ans != parsed_pred:
                print(f"Open: Original: {pred_ans}, Parsed: {parsed_pred}")

        eval_sample = {
            'id': res['question_id'],
            'question_type': res['question_type'],
            'answer': res['gt_answer'],  # the content in option, not answer index.
            'response': pred_ans,
            'parsed_pred': parsed_pred,
        }
        if is_multiple_choice:
            eval_sample['index2ans'] = res['index2ans']
        eval_samples.append(eval_sample)

        # Open answers are stored unparsed: mmmu_main_eval parses them itself.
        eval_output_dict[res['question_id']] = parsed_pred if is_multiple_choice else pred_ans

    # Close the output file explicitly so the dump is flushed to disk.
    with open("validation_mmmu_iter6000_merged.0.53.sorted.json", "w") as out_file:
        json.dump(eval_output_dict, out_file, indent=4, sort_keys=True)

    x = mmmu_main_eval(eval_output_dict,
                       task_cfg=tasks['mmmu'])
    print(x)