import plotly.colors | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import os | |
import matplotlib.pyplot as plt | |
import argparse | |
from utils.score_extract.ood_agg import ood_t2i_agg, ood_i2t_agg | |
def to_rgba(rgb, alpha=1):
    """Return the colour ``rgb`` as a CSS ``rgba(...)`` string with opacity ``alpha``.

    Accepted inputs:
      * ``"rgb(r, g, b)"`` (e.g. entries of ``plotly.colors.DEFAULT_PLOTLY_COLORS``).
        The channel text is kept verbatim, so ``"rgb(31, 119, 180)"`` with
        ``alpha=0.1`` becomes ``"rgba(31, 119, 180, 0.1)"``.
      * ``"rgba(r, g, b, a)"``; the existing alpha is replaced by ``alpha``.
      * ``"#rrggbb"`` / ``"#rgb"`` hex strings (used by plotly's qualitative palettes).

    Raises:
        ValueError: if ``rgb`` is not in one of the formats above.
    """
    color = rgb.strip()
    if color.startswith("rgb(") and color.endswith(")"):
        # "rgb(" -> "rgba(", drop the closing paren, append the alpha channel.
        return "rgba" + color[3:-1] + f", {alpha})"
    if color.startswith("rgba(") and color.endswith(")"):
        channels = [part.strip() for part in color[5:-1].split(",")]
        if len(channels) == 4:
            return f"rgba({', '.join(channels[:3])}, {alpha})"
    elif color.startswith("#"):
        digits = color[1:]
        if len(digits) == 3:
            # Short form "#abc" means "#aabbcc".
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) == 6 and all(ch in "0123456789abcdefABCDEF" for ch in digits):
            red, green, blue = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
            return f"rgba({red}, {green}, {blue}, {alpha})"
    raise ValueError(f"unsupported colour format: {rgb!r}")
def radar_plot(results, thetas, selected_models):
    """Build a radar chart of per-subfield scores plus a score table.

    Args:
        results: mapping ``model_name -> {subfield: score}``. Models are matched
            by basename (the part of ``selected_models`` entries after the last
            path separator).
        thetas: list of angular-axis labels, one per subfield. When every label
            is a key of a model's score dict, values are looked up by label so the
            radial values line up with the axes. Otherwise the dict's own
            iteration order is used.
        selected_models: model identifiers to plot; models absent from
            ``results`` are skipped (both in the chart and in the table).

    Returns:
        A plotly ``Figure``: polar chart in row 1, score table in row 2.
    """
    palette = plotly.colors.DEFAULT_PLOTLY_COLORS

    # Extract performance values for each model across all benchmarks.
    model_performance = {}
    for model in (os.path.basename(name) for name in selected_models):
        if model not in results:
            continue
        benchmarks_data = results[model]
        if all(theta in benchmarks_data for theta in thetas):
            # Look values up by axis label so every model's radial values line
            # up with `thetas`, whatever order its dict was built in.
            model_performance[model] = [benchmarks_data[theta] for theta in thetas]
        else:
            # Labels are display strings rather than keys: fall back to the
            # dict's own iteration order.
            model_performance[model] = list(benchmarks_data.values())

    # Create radar chart with plotly
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.2,
        row_heights=[1, 0.4],
        specs=[[{"type": "polar"}], [{"type": "table"}]],
    )
    for i, (model, performance) in enumerate(model_performance.items()):
        color = palette[i % len(palette)]
        fig.add_trace(
            go.Scatterpolar(
                # Repeat the first point so the polygon closes.
                r=performance + [performance[0]],
                theta=thetas + [thetas[0]],
                fill="toself",
                connectgaps=True,
                fillcolor=to_rgba(color, 0.1),
                line=dict(color=color),
                name=model,
            ),
            row=1, col=1,
        )

    header_texts = ["Model"] + [x.replace("<br>", " ") for x in thetas]
    # Table data is a list of columns: model names first, then one column per
    # subfield. Built from `model_performance` so skipped models cannot raise.
    rows = [list(model_performance)] + [
        [round(scores[i], 2) for scores in model_performance.values()]
        for i in range(len(thetas))
    ]
    fig.add_trace(
        go.Table(
            header=dict(values=header_texts, font=dict(size=14.5), align="left"),
            cells=dict(
                values=rows,
                align="left",
                font=dict(size=14.5),
                height=30,
            ),
        ),
        row=2, col=1,
    )
    fig.update_layout(
        height=900,
        legend=dict(font=dict(size=20), orientation="h", xanchor="center", x=0.5, y=0.35),
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],  # Assuming accuracy is a percentage between 0 and 100
                tickfont=dict(size=12),
            ),
            angularaxis=dict(tickfont=dict(size=20), type="category"),
        ),
        showlegend=True,
    )
    return fig
def main_radar_plot(main_scores, selected_models):
    """Build the overall radar chart (one axis per perspective) plus a score table.

    Args:
        main_scores: mapping ``model_basename -> {perspective: score}``.
        selected_models: model identifiers; each is looked up by basename. The
            first model's perspective order defines the axis order for all models.

    Returns:
        A plotly ``Figure``: polar chart in row 1, score table in row 2.
    """
    palette = plotly.colors.DEFAULT_PLOTLY_COLORS
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.2,
        row_heights=[1.0, 0.5],
        specs=[[{"type": "polar"}], [{"type": "table"}]],
    )
    model_scores = {}
    for model in selected_models:
        model_name = os.path.basename(model)
        model_scores[model_name] = main_scores[model_name]
    perspectives = list(model_scores[os.path.basename(selected_models[0])].keys())

    for i, (model_name, scores) in enumerate(model_scores.items()):
        color = palette[i % len(palette)]
        # Index by perspective (as the table below does) so radial values
        # follow the axis order even if dicts were built in different orders.
        values = [scores[perspective] for perspective in perspectives]
        fig.add_trace(
            go.Scatterpolar(
                # Repeat the first point so the polygon closes.
                r=values + [values[0]],
                theta=perspectives + [perspectives[0]],
                connectgaps=True,
                fill="toself",
                fillcolor=to_rgba(color, 0.1),
                line=dict(color=color),
                name=model_name,
            ),
            row=1, col=1,
        )

    header_texts = ["Model"] + perspectives
    # Table data is a list of columns: model names, then one column per perspective.
    rows = [
        list(model_scores),
        *[[round(scores[p], 2) for scores in model_scores.values()] for p in perspectives],
    ]
    column_widths = [10] + [5] * len(perspectives)
    fig.add_trace(
        go.Table(
            header=dict(values=header_texts, font=dict(size=14.5), align="left"),
            cells=dict(
                values=rows,
                align="left",
                font=dict(size=14.5),
                height=30,
            ),
            columnwidth=column_widths,
        ),
        row=2, col=1,
    )
    fig.update_layout(
        height=1200,
        legend=dict(font=dict(size=20), orientation="h", xanchor="center", x=0.5, y=0.4),
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],  # Assuming accuracy is a percentage between 0 and 100
                tickfont=dict(size=12),
            ),
            angularaxis=dict(tickfont=dict(size=20), type="category", rotation=5),
        ),
        showlegend=True,
        title=dict(text="MM-DecodingTrust Scores (Higher is Better)"),
    )
    return fig
def breakdown_plot(scenario_results, subfields, selected_models):
    """Return the per-subfield radar chart for one perspective.

    Thin wrapper that forwards straight to ``radar_plot`` with ``subfields``
    as the angular-axis labels.
    """
    return radar_plot(scenario_results, subfields, selected_models)
def update_subscores(target_model, main_scores, config_dicts):
    """Re-key one model's main scores by each perspective's display name.

    Args:
        target_model: model identifier; only the part after the last ``/`` is used.
        main_scores: mapping ``model_name -> {perspective_key: score}``.
        config_dicts: mapping ``perspective_key -> {"name": display_name, ...}``.

    Returns:
        ``{model_name: {display_name: score}}`` for ``target_model`` only,
        preserving the order of ``main_scores[model_name]``.
    """
    model_name = target_model.split('/')[-1]
    renamed = {
        config_dicts[perspective]["name"]: score
        for perspective, score in main_scores[model_name].items()
    }
    return {model_name: renamed}
def generate_plot(model, main_scores, sub_scores, config_dict, out_path="plots"):
    """Render and save one model's per-perspective breakdown charts and its main chart.

    Args:
        model: model identifier; scores are looked up by its basename, matching
            how ``radar_plot`` / ``main_radar_plot`` look models up.
        main_scores: mapping ``model_basename -> {perspective_key: score}``.
        sub_scores: mapping ``perspective_key -> {model_basename: {subfield: score}}``.
        config_dict: mapping ``perspective_key -> {"name": ..., "sub_plot": ...}``;
            perspectives whose ``"sub_plot"`` equals ``False`` get no breakdown chart.
        out_path: directory the PNG files are written to; created if missing.
    """
    model_name = os.path.basename(model)
    os.makedirs(out_path, exist_ok=True)
    curr_main_scores = update_subscores(model, main_scores, config_dict)
    for perspective, settings in config_dict.items():
        if settings["sub_plot"] == False:  # noqa: E712 -- explicit-False opt-out kept as-is
            continue
        scenario_results = sub_scores[perspective]
        # Look subfields up by basename: the full "org/name" key would not exist
        # in score dicts that radar_plot can read.
        subfields = list(scenario_results[model_name].keys())
        subplot = breakdown_plot(scenario_results, subfields, [model])
        perspective_name = settings["name"].replace(" ", "_")
        subplot.write_image(
            os.path.join(out_path, f"{perspective_name}_breakdown.png"), width=1400, height=700
        )
    plot = main_radar_plot(curr_main_scores, [model])
    plot.write_image(os.path.join(out_path, "main.png"), width=1400, height=700)
def generate_main_plot(models, main_scores):
    """Return the overall radar chart and score table figure for ``models``.

    ``main_scores`` is passed through unchanged to ``main_radar_plot``.
    """
    return main_radar_plot(main_scores, models)
def generate_sub_plot(models, sub_scores, perspective):
    """Return the per-subfield breakdown figure of ``perspective`` for ``models``.

    The subfield (axis) list is taken from the first model in ``models``.

    Args:
        models: non-empty list of model identifiers (``org/name`` or bare names).
        sub_scores: mapping ``perspective -> {model_basename: {subfield: score}}``.
        perspective: key into ``sub_scores``.
    """
    scenario_results = sub_scores[perspective]
    # Scores are keyed by basename (radar_plot matches models the same way), so
    # strip any "org/" prefix before reading the reference model's subfields.
    reference_model = os.path.basename(models[0])
    subfields = list(scenario_results[reference_model].keys())
    return breakdown_plot(scenario_results, subfields, models)
if __name__ == "__main__":
    t2i_models = [  # Average time spent running the following example
        "dall-e-2",
        "dall-e-3",
        "DeepFloyd/IF-I-M-v1.0",  # 15.372
        "dreamlike-art/dreamlike-photoreal-2.0",  # 3.526
        "prompthero/openjourney-v4",  # 4.981
        "stabilityai/stable-diffusion-xl-base-1.0",  # 7.463
    ]
    i2t_models = [  # Average time spent running the following example
        "gpt-4-vision-preview",
        "gpt-4o-2024-05-13",
        "llava-hf/llava-v1.6-vicuna-7b-hf",
    ]
    perspectives = ["Safety", "Fairness", "Hallucination", "Privacy", "Adv", "OOD"]

    def collect_scores(models, aggregate, result_dir="./data/results"):
        """Return ``(main_scores, sub_scores)`` for ``models`` using ``aggregate``.

        ``main_scores`` is ``{model: {perspective: score}}`` and ``sub_scores`` is
        ``{perspective: {model: subscenarios}}``. Models are keyed by the part of
        their name after the last ``/``.
        """
        main_scores, sub_scores = {}, {}
        for model in models:
            name = model.split("/")[-1]
            # Aggregate once per model; the same result feeds every perspective.
            result = aggregate(name, result_dir)
            # Placeholder: every perspective currently reuses the OOD aggregate.
            main_scores[name] = {p: result["score"] for p in perspectives}
            for p in perspectives:
                sub_scores.setdefault(p, {})[name] = result["subscenarios"]
        return main_scores, sub_scores

    main_scores_t2i, sub_scores_t2i = collect_scores(t2i_models, ood_t2i_agg)
    main_scores_i2t, sub_scores_i2t = collect_scores(i2t_models, ood_i2t_agg)

    # generate_main_plot(t2i_models, main_scores_t2i)
    # generate_main_plot(i2t_models, main_scores_i2t)
    # sub_scores_* are keyed by the entries of `perspectives`, hence "OOD" (not "ood").
    generate_sub_plot(t2i_models, sub_scores_t2i, "OOD")
    # generate_sub_plot(i2t_models, sub_scores_i2t, "OOD")