|
""" |
|
This script plots the results of the human evaluation on Amazon Mechanical Turk, in which
human participants chose between an image generated by ClimateGAN and one generated by a different method.
|
""" |
|
# Announce the (potentially slow) imports; the line is completed once they are done
print("Imports...", end="")

from argparse import ArgumentParser
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import yaml

# Terminate the progress line so later output does not run on from "Imports..."
print(" Done.")
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maps each method identifier found in the input CSV's "comparable" column to the
# display name used as its y-axis label on the figure. Insertion order matters:
# it is the order of the method colours defined below.
comparables_dict = dict(
    munit_flooded="MUNIT",
    cyclegan="CycleGAN",
    instagan="InstaGAN",
    instagan_copypaste="Mask-InstaGAN",
    painted_ground="Painted ground",
)
|
|
|
|
|
|
|
# Colour-blind-friendly palette from which every colour of the figure is picked
# (computed once; it was previously built twice with identical arguments)
palette_colorblind = sns.color_palette("colorblind")

# ClimateGAN uses a palette entry that no comparable method uses
color_climategan = palette_colorblind[9]

# One colour per comparable method
color_munit = palette_colorblind[1]
color_cyclegan = palette_colorblind[2]
color_instagan = palette_colorblind[3]
color_maskinstagan = palette_colorblind[6]
color_paintedground = palette_colorblind[8]

# Ordered like ``comparables_dict``: MUNIT, CycleGAN, InstaGAN, Mask-InstaGAN,
# Painted ground
palette_comparables = [
    color_munit,
    color_cyclegan,
    color_instagan,
    color_maskinstagan,
    color_paintedground,
]

# Lighter tint of each comparable colour: the middle shade of a 3-step light palette
palette_comparables_light = [
    sns.light_palette(color, n_colors=3)[1] for color in palette_comparables
]
|
|
|
|
|
def parsed_args(argv=None):
    """
    Parse and return the command-line arguments.

    Args:
        argv (list[str], optional): arguments to parse instead of the real
            command line (``sys.argv[1:]``). Defaults to None, i.e. the real
            command line.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="amt_omni-vs-other.csv",
        type=str,
        help="CSV containing the results of the human evaluation, pre-processed",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # argparse only applies ``type`` to *string* defaults: a default of 1e6
        # would remain a float, so the default is written as an int literal
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )

    return parser.parse_args(argv)
|
|
|
|
|
if __name__ == "__main__":
    # -------------------------------------------------------------------------
    # Arguments and output directory
    # -------------------------------------------------------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Use --output_dir if given, otherwise $SLURM_TMPDIR. Fail with a readable
    # message rather than a bare KeyError when neither is available.
    if args.output_dir is None:
        slurm_tmpdir = os.environ.get("SLURM_TMPDIR")
        if slurm_tmpdir is None:
            raise SystemExit(
                "No --output_dir given and $SLURM_TMPDIR is not set: "
                "please pass --output_dir explicitly."
            )
        output_dir = Path(slurm_tmpdir)
    else:
        output_dir = Path(args.output_dir)
    # exist_ok=True replaces the racy exists()-then-mkdir() check
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store the arguments next to the figure, for reproducibility
    output_yml = output_dir / "args_human_evaluation.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # -------------------------------------------------------------------------
    # Data
    # -------------------------------------------------------------------------
    df = pd.read_csv(args.input_csv)
    # Only rows flagged as valid are plotted
    df_valid = df.loc[df.is_valid]

    # Order the comparables by decreasing mean of "climategan" over the valid rows,
    # i.e. the very quantity the foreground bars display. (A sum over all rows also
    # counts invalid answers and scales with each method's number of answers, so
    # bars sorted that way need not be monotonic.)
    climategan_rate = df_valid.groupby("comparable")["climategan"].mean()
    comparables = climategan_rate.sort_values(ascending=False).index.to_numpy()

    # -------------------------------------------------------------------------
    # Style
    # -------------------------------------------------------------------------
    sns.set(style="whitegrid")
    plt.rcParams.update(
        {
            "font.family": "serif",
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # Separate entries: adjacent string literals without a comma would
                # silently concatenate into a non-existent font name
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ],
        }
    )
    fontsize = "medium"

    # -------------------------------------------------------------------------
    # Plot
    # -------------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(10.5, 3), dpi=args.dpi)

    # A list palette is applied to the bars by position in ``order``, which is
    # data-dependent here. Key the light colours by method identifier instead so
    # every method keeps its own colour wherever it lands in the ranking.
    # (comparables_dict and palette_comparables_light share the same method order.)
    palette_by_comparable = dict(zip(comparables_dict, palette_comparables_light))

    # Background bars: on valid rows "is_valid" is always True, so its mean is 1
    # and each comparable gets a full-width bar in its light colour
    sns.barplot(
        data=df_valid,
        x="is_valid",
        y="comparable",
        order=comparables,
        orient="h",
        label="comparable",
        palette=palette_by_comparable,
        ci=None,
        ax=ax,
    )

    # Foreground bars: mean of "climategan" per comparable, with a 99% bootstrap
    # confidence interval
    sns.barplot(
        data=df_valid,
        x="climategan",
        y="comparable",
        order=comparables,
        orient="h",
        label="climategan",
        color=color_climategan,
        ci=99,
        n_boot=args.n_bs,
        seed=args.bs_seed,
        errcolor="black",
        errwidth=1.5,
        capsize=0.1,
        ax=ax,
    )

    # Dotted vertical reference line at x = 0.5. The upward arange from ylim[1]
    # presumes an inverted y axis (as seaborn sets up for horizontal categorical
    # plots), i.e. ylim[1] is the top edge.
    ylim = ax.get_ylim()
    y = np.arange(ylim[1] + 0.1, ylim[0], 0.1)
    x = np.full_like(y, 0.5)
    ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")

    # Replace the raw method identifiers with their display names
    yticklabels = [comparables_dict[ytick.get_text()] for ytick in ax.get_yticklabels()]
    yticklabels_text = ax.set_yticklabels(
        yticklabels, fontsize=fontsize, horizontalalignment="right", x=0.96
    )
    # NOTE(review): re-adding the tick labels as axes artists looks intended to draw
    # them above the bars (the whitegrid style draws the axis below patches) —
    # confirm before removing.
    for ytl in yticklabels_text:
        ax.add_artist(ytl)

    ax.set_ylabel(ylabel="")

    # x ticks every 0.1 from 0 to 1
    xticks = np.arange(0.0, 1.1, 0.1)
    ax.set(xticks=xticks)
    plt.setp(ax.get_xticklabels(), fontsize=fontsize)
    ax.set_xlabel(None)

    sns.despine(left=True, bottom=True)

    # -------------------------------------------------------------------------
    # Save
    # -------------------------------------------------------------------------
    output_fig = output_dir / "human_evaluation_rate_climategan.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved figure to {output_fig}")
|
|