File size: 1,388 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def display_images(data, n_rows=3, n_cols=3):
    """Draw images from *data* on an ``n_rows`` x ``n_cols`` grid.

    Each image is passed through ``unnormalize_if_necessary`` before being
    handed to ``Axes.imshow``. Images and axes are paired with ``zip``, so at
    most ``n_rows * n_cols`` images are drawn, and any axes left over when
    *data* is shorter stay empty.

    Args:
        data: Iterable of images accepted by ``Axes.imshow``.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.

    Returns:
        The matplotlib figure holding the grid.
    """
    # squeeze=False always returns a 2-D array of Axes; with the default
    # squeeze=True a 1x1 grid yields a bare Axes object and .flatten() fails.
    figure, axs = plt.subplots(
        n_rows, n_cols, figsize=(24, 12), squeeze=False
    )
    axs = axs.flatten()

    # Hide tick marks and remove the gaps between neighbouring images.
    plt.setp(axs, xticks=[], yticks=[])
    plt.subplots_adjust(wspace=0, hspace=0)

    for img, ax in zip(data, axs):
        ax.imshow(unnormalize_if_necessary(img))

    return figure


def unnormalize_if_necessary(x):
    """Rescale *x* with ``x * 0.5 + 0.5`` when it contains negative values.

    The rescale maps the interval [-1, 1] onto [0, 1]; presumably inputs with
    negative values were normalized to [-1, 1] upstream (confirm against the
    data pipeline). Only ``np.ndarray`` and ``tf.Tensor`` inputs are
    inspected; any other type is returned unchanged.

    Args:
        x: A NumPy array, a TensorFlow tensor, or any other object.

    Returns:
        The rescaled array/tensor if its minimum is below zero, otherwise
        *x* itself. Empty arrays/tensors are returned unchanged.
    """
    if isinstance(x, np.ndarray):
        values = x
    elif isinstance(x, tf.Tensor):
        # .numpy() is needed to inspect the values on the host.
        values = x.numpy()
    else:
        return x

    # Guard on size: min() of a zero-size array raises ValueError.
    if values.size and values.min() < 0:
        return (x * 0.5) + 0.5
    return x


def display_true_pred(y_true, y_pred, n_cols=3):
    """Plot predictions above ground truth, ``n_cols`` images per row.

    Both batches are passed through ``unnormalize_if_necessary`` as a whole,
    then the first ``n_cols`` entries of each are drawn with ``Axes.imshow``.
    The top row is titled "Prediction" and the bottom row "Ground truth".

    Args:
        y_true: Indexable batch of ground-truth images.
        y_pred: Indexable batch of predicted images.
        n_cols: Number of images shown per row; each batch must hold at
            least this many entries.

    Returns:
        The matplotlib figure containing both rows.
    """
    fig = plt.figure(constrained_layout=True, figsize=(24, 12))

    y_true = unnormalize_if_necessary(y_true)
    y_pred = unnormalize_if_necessary(y_pred)

    # Row order: predictions first, ground truth second.
    panels = (("Prediction", y_pred), ("Ground truth", y_true))

    # create 2x1 subfigs
    subfigs = fig.subfigures(nrows=2, ncols=1)
    for subfig, (title, batch) in zip(subfigs, panels):
        subfig.suptitle(title, fontsize=24)

        # create 1xn_cols subplots per subfig; squeeze=False keeps the result
        # a 2-D array so n_cols=1 does not collapse to a single Axes object.
        axs = subfig.subplots(nrows=1, ncols=n_cols, squeeze=False)
        for col, ax in enumerate(axs[0]):
            ax.imshow(batch[col])

    return fig