|
""" |
|
This script plots examples of the images that get the best and worst metrics
|
""" |
|
print("Imports...", end="") |
|
import os |
|
import sys |
|
from argparse import ArgumentParser |
|
from pathlib import Path |
|
|
|
import matplotlib.patches as mpatches |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pandas as pd |
|
import seaborn as sns |
|
import yaml |
|
from imageio import imread |
|
from matplotlib.gridspec import GridSpec |
|
from skimage.color import rgba2rgb |
|
from sklearn.metrics.pairwise import euclidean_distances |
|
|
|
sys.path.append("../") |
|
|
|
from climategan.data import encode_mask_label |
|
from climategan.eval_metrics import edges_coherence_std_min |
|
from eval_masker import crop_and_resize |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Metrics for which example images are plotted, in plotting order
metrics = ["error", "f05", "edge_coherence"]

# Human-readable axis/legend names for every metric column, plus the subset
# of metrics considered key
dict_metrics = dict(
    names={
        "tpr": "TPR, Recall, Sensitivity",
        "tnr": "TNR, Specificity, Selectivity",
        "fpr": "FPR",
        "fpt": "False positives relative to image size",
        "fnr": "FNR, Miss rate",
        "fnt": "False negatives relative to image size",
        "mpr": "May positive rate (MPR)",
        "mnr": "May negative rate (MNR)",
        "accuracy": "Accuracy (ignoring may)",
        "error": "Error",
        "f05": "F05 score",
        "precision": "Precision",
        "edge_coherence": "Edge coherence",
        "accuracy_must_may": "Accuracy (ignoring cannot)",
    },
    key_metrics=["error", "f05", "edge_coherence"],
)
|
|
|
|
|
|
|
# Colours of the label categories and of the prediction overlay, picked from
# seaborn's "colorblind" palette
colorblind_palette = sns.color_palette("colorblind")
color_cannot, color_must, color_may, color_pred = (
    colorblind_palette[position] for position in (1, 2, 7, 4)
)

# Colours of the TP / TN / FP / FN overlays, picked from a 5-colour sample of
# seaborn's "icefire" palette
icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
color_tp, color_tn, color_fp, color_fn = (
    icefire[position] for position in (0, 1, 4, 3)
)
|
|
|
|
|
def parsed_args():
    """
    Parse and return the command-line arguments of the script

    Every option has a default, so the function never fails on missing
    options; paths that default to None must be validated by the caller.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="ablations_metrics_20210311.csv",
        type=str,
        help="CSV containing the results of the ablation study",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--models_log_path",
        default=None,
        type=str,
        help="Path containing the log files of the models",
    )
    parser.add_argument(
        "--masker_test_set_dir",
        default=None,
        type=str,
        help="Directory containing the test images",
    )
    parser.add_argument(
        "--best_model",
        default="dada, msd_spade, pseudo",
        type=str,
        help="The string identifier of the best model",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--alpha",
        default=0.5,
        type=float,
        help="Transparency of labels shade",
    )
    parser.add_argument(
        "--percentile",
        default=0.05,
        type=float,
        # The original help text was a copy-paste of the --alpha help
        help="Fraction of the images, sorted by metric, among which the "
        "best and worst examples are randomly drawn",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    parser.add_argument(
        "--no_images",
        action="store_true",
        default=False,
        help="Do not generate images",
    )

    return parser.parse_args()
|
|
|
|
|
def map_color(arr, input_color, output_color, rtol=1e-09):
    """
    Return a copy of `arr` in which every pixel close to `input_color` is
    replaced by `output_color`

    A pixel matches when all of its channels (last axis, index 2) are close
    to the corresponding channel of `input_color` according to np.isclose
    with the given relative tolerance. The input array is left untouched.
    """
    # Full-size reference image filled with the colour to look for
    reference = np.tile(input_color, (arr.shape[:2] + (1,)))
    channel_matches = np.isclose(arr, reference, rtol=rtol)
    pixel_matches = np.all(channel_matches, axis=2)

    recolored = arr.copy()
    recolored[pixel_matches] = output_color
    return recolored
|
|
|
|
|
def plot_labels(ax, img, label, img_id, n_, add_title, do_legend):
    """
    Show the image with its colour-coded labels on top and write the image
    identifier to the left of the axes

    Label pixels equal to (255, 0, 0), (0, 0, 255) and (0, 0, 0) are drawn
    with the "cannot", "must" and "may" colours respectively. The identifier
    is green when `n_` is 1, 3 or 5 and red otherwise. `do_legend` is
    accepted but not used.
    """
    recoloring = (
        ((255, 0, 0), color_cannot),
        ((0, 0, 255), color_must),
        ((0, 0, 0), color_may),
    )
    overlay = label.astype(float)
    for source_rgb, target_color in recoloring:
        overlay = map_color(overlay, source_rgb, target_color)

    ax.imshow(img)
    ax.imshow(overlay, alpha=0.5)
    ax.axis("off")

    id_color = "green" if n_ in [1, 3, 5] else "red"

    # Coordinates are in axes fraction: the text sits left of the image
    ax.text(
        -0.15,
        0.5,
        img_id,
        color=id_color,
        fontweight="roman",
        fontsize="x-large",
        horizontalalignment="left",
        verticalalignment="center",
        transform=ax.transAxes,
    )

    if add_title:
        ax.set_title("Labels", rotation=0, fontsize="x-large")
|
|
|
|
|
def plot_pred(ax, img, pred, img_id, add_title, do_legend):
    """
    Show the image with the predicted mask shaded on top of it

    The single-channel prediction is replicated over 3 channels, pixels equal
    to (1, 1, 1) are painted with the prediction colour, and everywhere else
    the original image is used. `img_id` and `do_legend` are accepted but not
    used.
    """
    pred_rgb = np.tile(pred[:, :, np.newaxis], reps=(1, 1, 3))

    colored = map_color(pred_rgb.astype(float), (1, 1, 1), color_pred)
    # Mask whatever is not the prediction colour, then fill the masked
    # positions with the image itself
    colored_ma = np.ma.masked_not_equal(colored, color_pred)
    blended = colored_ma.mask * img + colored_ma

    ax.imshow(img)
    ax.imshow(blended, alpha=0.5)
    ax.axis("off")

    if add_title:
        ax.set_title("Prediction", rotation=0, fontsize="x-large")
|
|
|
|
|
def plot_correct_incorrect(
    ax, img_filename, img, metric, label, img_id, n_, add_title, do_legend
):
    """
    Show the image with the FP / FN / TP / TN maps of the model and the
    "may" label area shaded on top of it

    The four maps are read from the `eval-metrics` sub-folders of the
    module-level `model_path`, using the stem of `img_filename`. `metric`,
    `img_id`, `n_` and `do_legend` are accepted but not used.
    """
    stem = Path(img_filename).stem

    # (sub-folder, file name pattern, overlay colour), in summation order
    map_sources = (
        ("eval-metrics/fp", "{}_fp.png", color_fp),
        ("eval-metrics/fn", "{}_fn.png", color_fn),
        ("eval-metrics/tp", "{}_tp.png", color_tp),
        ("eval-metrics/tn", "{}_tn.png", color_tn),
    )
    maps = None
    for folder, pattern, color in map_sources:
        raw = imread(model_path / folder / pattern.format(stem))
        raw_rgb = np.tile(np.expand_dims(raw, axis=2), reps=(1, 1, 3))
        colored = map_color(raw_rgb.astype(float), (1, 1, 1), color)
        maps = colored if maps is None else maps + colored

    # "May" area of the labels; everywhere else the image shows through
    may_colored = map_color(label.astype(float), (0, 0, 0), color_may)
    may_ma = np.ma.masked_not_equal(may_colored, color_may)
    may_ma = may_ma.mask * img + may_ma

    # Pixels not covered by any of the four maps show the image
    maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
    maps_ma = maps_ma.mask * img + maps_ma

    ax.imshow(img)
    ax.imshow(may_ma, alpha=0.5)
    ax.imshow(maps_ma, alpha=0.5)
    ax.axis("off")

    if add_title:
        ax.set_title("Metric", rotation=0, fontsize="x-large")
|
|
|
|
|
def plot_edge_coherence(
    ax,
    img,
    metric,
    label,
    pred,
    img_id,
    n_,
    add_title,
    do_legend,
    img_filename=None,
):
    """
    Show the edges of the prediction and of the "must" label area, the areas
    they enclose, and white segments linking sampled prediction-edge pixels
    to their closest label-edge pixel

    Args:
        ax: matplotlib axes to draw on
        img: image, used wherever no overlay is drawn
        metric, img_id, n_, do_legend: accepted for signature compatibility,
            not used
        label: RGB label map; (0, 0, 255) pixels are drawn with the "must"
            colour
        pred: single-channel predicted mask
        add_title: whether to set the axes title
        img_filename: file name of the image, used to locate the TP map
            under the module-level `model_path`. When None, the function
            falls back to the module-level `srs_sel.filename`, which is what
            it always read before this parameter existed.
    """
    if img_filename is None:
        # Backward compatibility with callers that rely on the global
        img_filename = srs_sel.filename

    pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))

    # Only the two edge maps are needed here; the metric value is discarded
    _, pred_ec, label_ec = edges_coherence_std_min(
        np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
    )

    # (row, column) coordinates of the edge pixels
    pred_ec_coord = np.argwhere(pred_ec > 0)
    label_ec_coord = np.argwhere(label_ec > 0)

    # Edges of the prediction and of the labels, in their respective colours
    pred_ec = np.tile(
        np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)

    label_ec = np.tile(
        np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
    )
    label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)

    # Both edges on top of the image
    combined_ec = pred_ec_colmap + label_ec_colmap
    combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
    combined_ec_img = combined_ec_ma.mask * img + combined_ec

    # Predicted area
    pred_colmap = pred.astype(float)
    pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
    pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)

    # "Must" label area
    label_colmap = label.astype(float)
    label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
    label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)

    # True positives
    tp_map = imread(
        model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
    )
    tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
    tp_map_colmap = tp_map.astype(float)
    tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
    tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)

    # Combine the three areas, leaving the edge pixels untouched
    comb_pred = (
        (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
        & tp_map_colmap_ma.mask
        & combined_ec_ma.mask
    ) * pred_colmap
    comb_label = (
        (label_colmap_ma.mask ^ pred_colmap_ma.mask)
        & pred_colmap_ma.mask
        & combined_ec_ma.mask
    ) * label_colmap
    comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
    combined = comb_tp + comb_label + comb_pred
    combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
    combined_ma = combined_ma.mask * combined_ec_img + combined_ma

    ax.imshow(combined_ec_img, alpha=1)
    ax.imshow(combined_ma, alpha=0.5)
    ax.axis("off")

    # Link every `offset`-th prediction-edge pixel (sorted by column) to its
    # closest label-edge pixel. Distances are computed for the sampled pixels
    # only; a constant normalisation would not change the argmin, so none is
    # applied.
    idx_sort_x = np.argsort(pred_ec_coord[:, 1])
    offset = 100
    sampled_coord = pred_ec_coord[idx_sort_x[offset::offset]]
    # Nothing to draw when no pixel is sampled or the labels have no edge
    if sampled_coord.shape[0] > 0 and label_ec_coord.shape[0] > 0:
        dist_mat = euclidean_distances(sampled_coord, label_ec_coord)
        closest_coord = label_ec_coord[np.argmin(dist_mat, axis=1)]
        for (y0, x0), (y1, x1) in zip(sampled_coord, closest_coord):
            ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)

    if add_title:
        ax.set_title("Metric", rotation=0, fontsize="x-large")
|
|
|
|
|
def plot_images_metric(
    axes, metric, img_filename, img_id, n_, srs_sel, add_title, do_legend
):
    """
    Fill one row of 4 axes for one image: labels, prediction, metric
    visualisation and legend

    Args:
        axes: sequence of at least 4 matplotlib axes
        metric: "error", "f05" or "edge_coherence"; selects the plot drawn in
            the third axes and the legend content
        img_filename: file name of the image, looked up under the
            module-level `imgs_orig_path`, `labels_path` and `model_path`
        img_id: identifier written next to the row
        n_: 1-based row counter; rows 1, 3 and 5 are captioned as the good
            examples of their metric
        srs_sel: record with `error`, `f05` and `edge_coherence` attributes,
            printed in the legend
        add_title: whether to put column titles on the axes
        do_legend: accepted but not used

    Raises:
        ValueError: if `metric` is not one of the supported metrics
    """
    # Read the image, its labels and the prediction of the model
    img_path = imgs_orig_path / img_filename
    label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
    img, label = crop_and_resize(img_path, label_path)
    img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0

    pred = imread(
        model_path / "eval-metrics/pred" / "{}_pred.png".format(Path(img_filename).stem)
    )

    plot_labels(axes[0], img, label, img_id, n_, add_title, do_legend)
    plot_pred(axes[1], img, pred, img_id, add_title, do_legend)

    is_good_example = n_ in [1, 3, 5]

    if metric in ["error", "f05"]:
        plot_correct_incorrect(
            axes[2],
            img_filename,
            img,
            metric,
            label,
            img_id,
            n_,
            add_title,
            do_legend=False,
        )
        legend_entries = [
            (color_tp, "TP"),
            (color_tn, "TN"),
            (color_fp, "FP"),
            (color_fn, "FN"),
            (color_may, "May-be-flooded"),
        ]
        if metric == "error":
            title = "Low error rate" if is_good_example else "High error rate"
        else:
            title = "High F05 score" if is_good_example else "Low F05 score"

    elif metric == "edge_coherence":
        plot_edge_coherence(
            axes[2], img, metric, label, pred, img_id, n_, add_title, do_legend=False
        )
        legend_entries = [
            (color_tp, "TP"),
            (color_pred, "Prediction"),
            (color_must, "Must-be-flooded"),
        ]
        title = "High edge coherence" if is_good_example else "Low edge coherence"

    else:
        raise ValueError(
            "Unsupported metric {!r}: expected 'error', 'f05' or "
            "'edge_coherence'".format(metric)
        )

    handles = [
        mpatches.Patch(facecolor=color, label=name, linewidth=1.0, alpha=0.66)
        for color, name in legend_entries
    ]
    labels = [name for _, name in legend_entries]

    # "F05" is spelled with a zero, consistently with the rest of the script
    labels_values_title = "Error: {:.4f} \nF05: {:.4f} \nEdge coherence: {:.4f}".format(
        srs_sel.error, srs_sel.f05, srs_sel.edge_coherence
    )

    plot_legend(axes[3], img, handles, labels, labels_values_title, title)
|
|
|
|
|
def plot_legend(ax, img, handles, labels, labels_values_title, title):
    """
    Draw two legends on a white canvas of the same shape as `img`

    The upper-left legend lists `handles` / `labels` under `title`; the
    lower-left one has no entries and only shows `labels_values_title`.
    """
    # White background with the footprint of the image
    canvas = np.zeros_like(img, dtype=np.uint8)
    canvas[...] = 255
    ax.imshow(canvas)
    ax.axis("off")

    entries_legend = ax.legend(
        handles=handles,
        labels=labels,
        title=title,
        title_fontsize="medium",
        labelspacing=0.6,
        loc="upper left",
        fontsize="x-small",
        frameon=False,
    )
    entries_legend._legend_box.align = "left"

    values_legend = ax.legend(
        title=labels_values_title,
        title_fontsize="small",
        loc="lower left",
        frameon=False,
    )
    values_legend._legend_box.align = "left"

    # The second call to ax.legend replaced the first legend: add it back
    ax.add_artist(entries_legend)
|
|
|
|
|
def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
    """
    Scatter plot of `y_metric` against `x_metric` for the rows of `df`, with
    the example images of `dict_images` annotated

    Axis labels are the display names found in `dict_metrics["names"]`.
    """
    sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)

    # Axis labels
    metric_names = dict_metrics["names"]
    ax.set_xlabel(metric_names[x_metric], rotation=0, fontsize="medium")
    ax.set_ylabel(metric_names[y_metric], rotation=90, fontsize="medium")

    # No spines on the left and bottom sides
    sns.despine(ax=ax, left=True, bottom=True)

    annotate_scatterplot(ax, dict_images, x_metric, y_metric)
|
|
|
|
|
def scatterplot_metrics(ax, df, df_all, dict_images, plot_all=False):
    """
    Scatter plot of F05 score against error, coloured by edge coherence,
    with the example images of `dict_images` annotated

    Args:
        ax: matplotlib axes to draw on
        df: rows of the best model, always plotted
        df_all: rows of all models, only used when `plot_all` is True
        dict_images: mapping from image identifier to its record, forwarded
            to annotate_scatterplot
        plot_all: if True, also plot in the background the rows of `df_all`
            split in three groups: `ground`, `instagan`, and the rows that
            are neither and do not belong to the best model (read from the
            module-level `args.best_model`)
    """
    if plot_all:
        # The third group used to test `instagan == False` twice, so ground
        # models were wrongly included in it
        background_groups = (
            (df_all.ground == True, "+"),  # noqa: E712
            (df_all.instagan == True, "x"),  # noqa: E712
            (
                (df_all.ground == False)  # noqa: E712
                & (df_all.instagan == False)  # noqa: E712
                & (df_all.model_feats != args.best_model),
                "s",
            ),
        )
        for row_mask, marker in background_groups:
            sns.scatterplot(
                data=df_all.loc[row_mask],
                x="error",
                y="f05",
                hue="edge_coherence",
                ax=ax,
                marker=marker,
                alpha=0.25,
            )

    # Best model, coloured by edge coherence
    cmap_ = sns.cubehelix_palette(as_cmap=True)
    sns.scatterplot(
        data=df, x="error", y="f05", hue="edge_coherence", ax=ax, palette=cmap_
    )

    norm = plt.Normalize(df["edge_coherence"].min(), df["edge_coherence"].max())
    sm = plt.cm.ScalarMappable(cmap=cmap_, norm=norm)
    sm.set_array([])

    # Replace the legend created by seaborn with a colour bar
    ax.get_legend().remove()
    ax_cbar = ax.figure.colorbar(sm)
    ax_cbar.set_label("Edge coherence", labelpad=8)

    # Axis labels
    ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")
    ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")

    annotate_scatterplot(ax, dict_images, "error", "f05")

    sns.despine(ax=ax, left=True, bottom=True)

    # Start the error axis at 0 and end the F05 axis at 1
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    ax.set_xlim([0.0, xlim[1]])
    ax.set_ylim([ylim[0], 1.0])
|
|
|
|
|
def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
    """
    Annotate the example images on a scatter plot with labelled arrows

    Images "B", "D" and "F" get one arrow each. Images "A", "C" and "E"
    share a single arrow, labelled "A, C, E", that points at the mean of
    their three positions; it is drawn when key "A" is met. Every text is
    shifted away from its point towards the middle of the axes, by a
    fraction `offset` of the axis lengths.
    """
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    x_len = x_max - x_min
    y_len = y_max - y_min
    # Mid-points of the axes: they decide the direction of the shift
    x_th = x_max - x_len / 2.0
    y_th = y_max - y_len / 2.0

    def _towards_middle(value, middle, shift):
        return value + shift if value < middle else value - shift

    def _arrow(caption, x, y, x_shift, y_shift):
        ax.annotate(
            xy=(x, y),
            xycoords="data",
            xytext=(
                _towards_middle(x, x_th, x_shift),
                _towards_middle(y, y_th, y_shift),
            ),
            textcoords="data",
            text=caption,
            arrowprops=dict(facecolor="black", shrink=0.05),
            fontsize="medium",
            color="black",
        )

    for text, d in dict_images.items():
        if text in ["B", "D", "F"]:
            _arrow(text, d[x_metric], d[y_metric], x_len * offset, y_len * offset)
        elif text == "A":
            first, second, third = (dict_images[key] for key in ("A", "C", "E"))
            x_mean = (first[x_metric] + second[x_metric] + third[x_metric]) / 3
            y_mean = (first[y_metric] + second[y_metric] + third[y_metric]) / 3
            _arrow(
                "A, C, E",
                x_mean,
                y_mean,
                x_len * 2 * offset,
                y_len * 0.45 * offset,
            )
|
|
|
|
|
if __name__ == "__main__":
    # NOTE: the names bound in this block are module-level globals. The
    # plotting functions above read `args`, `imgs_orig_path`, `labels_path`,
    # `model_path` and `srs_sel` directly, so these must keep their names and
    # must not be moved into a function.

    # Parse and validate the arguments
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    for required in ("models_log_path", "masker_test_set_dir"):
        if getattr(args, required) is None:
            raise SystemExit("Argument --{} is required".format(required))

    # Output directory: defaults to the temporary directory of the SLURM job
    if args.output_dir is None:
        slurm_tmpdir = os.environ.get("SLURM_TMPDIR")
        if slurm_tmpdir is None:
            raise SystemExit(
                "--output_dir was not given and SLURM_TMPDIR is not set"
            )
        output_dir = Path(slurm_tmpdir)
    else:
        output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store the arguments next to the figures
    output_yml = output_dir / "labels.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Data directories
    imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
    labels_path = Path(args.masker_test_set_dir) / "labels"

    # Read the CSV and keep the rows of the best model
    df_all = pd.read_csv(args.input_csv, index_col="model_img_idx")
    df = df_all.loc[df_all.model_feats == args.best_model]
    if df.empty:
        raise SystemExit(
            "No row of {} has model_feats == {!r}".format(
                args.input_csv, args.best_model
            )
        )
    v_key, model_dir = df.model.unique()[0].split("/")
    model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir

    # Plot style
    sns.reset_orig()
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # A missing comma used to merge the next two entries into
                # the non-existent font "serifBitstream Vera Serif"
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )

    # `is not None` so that 0 is honoured as a seed
    if args.seed is not None:
        np.random.seed(args.seed)
    img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    dict_images = {}
    idx = 0

    # Number of best (or worst) images among which one example is drawn; at
    # least 1 so that small data sets do not fail with an IndexError
    n_candidates = min(len(df), max(1, int(args.percentile * len(df))))

    # Layout: one grid of 2 rows x 4 columns per metric, stacked vertically
    grid_vmargin = 0.03  # Extent of the vertical margin between metric grids
    ax_hspace = 0.04  # Extent of the vertical space between axes of same grid
    ax_wspace = 0.05  # Extent of the horizontal space between axes of same grid
    n_grids = len(metrics)
    n_cols = 4
    n_rows = 2
    h_grid = (1.0 / n_grids) - ((n_grids - 1) * grid_vmargin) / n_grids

    fig1 = plt.figure(dpi=args.dpi, figsize=(11, 13))

    n_ = 0
    for metric_id, metric in enumerate(metrics):

        # Vertical extent of the grid of this metric
        top_grid = 1.0 - metric_id * h_grid - metric_id * grid_vmargin
        bottom_grid = top_grid - h_grid
        gridspec = GridSpec(
            n_rows,
            n_cols,
            wspace=ax_wspace,
            hspace=ax_hspace,
            bottom=bottom_grid,
            top=top_grid,
        )

        # Row 0 shows a good example and row 1 a bad one: low values are good
        # for the error, high values are good for the other metrics
        for row, best_first in enumerate((True, False)):
            ascending = (metric == "error") == best_first

            # Random image among the first `n_candidates` of the sorted rows
            idx_rand = np.random.permutation(n_candidates)[0]
            srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
            img_id = img_ids[idx]
            dict_images[img_id] = srs_sel

            img_filename = srs_sel.filename

            axes_row = [fig1.add_subplot(gridspec[row, c]) for c in range(n_cols)]
            if not args.no_images:
                n_ += 1
                plot_images_metric(
                    axes_row,
                    metric,
                    img_filename,
                    img_id,
                    n_,
                    srs_sel,
                    # Column titles only on the very first row of the figure
                    add_title=(metric_id == 0 and row == 0),
                    do_legend=False,
                )

            idx += 1
            print("1 more row done.")

    output_fig = output_dir / "all_metrics.png"

    fig1.tight_layout()
    fig1.savefig(output_fig, dpi=fig1.dpi, bbox_inches="tight")

    # Scatter plot of the metrics of the best model
    fig2 = plt.figure(dpi=args.dpi)

    scatterplot_metrics(fig2.gca(), df, df_all, dict_images)

    output_fig = output_dir / "scatterplots.png"
    fig2.savefig(output_fig, dpi=fig2.dpi, bbox_inches="tight")
|
|