|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
def display_images(data, n_rows=3, n_cols=3):
    """Show images from ``data`` on an ``n_rows`` x ``n_cols`` grid of axes.

    Each image is passed through ``unnormalize_if_necessary`` and then drawn
    with ``imshow``. Images and axes are paired in order and pairing stops at
    whichever runs out first, so surplus images are ignored and surplus axes
    stay empty.

    Args:
        data: Iterable of images in a form ``Axes.imshow`` accepts.
        n_rows: Number of rows in the grid.
        n_cols: Number of columns in the grid.

    Returns:
        The matplotlib figure that holds the grid.
    """
    # squeeze=False always yields a 2-D array of axes. Without it a 1x1 grid
    # comes back as a bare Axes object, which has no ``flatten`` method.
    figure, axs = plt.subplots(n_rows, n_cols, figsize=(24, 12), squeeze=False)
    axs = axs.flatten()

    # Remove tick marks and the spacing between neighbouring cells.
    plt.setp(axs, xticks=[], yticks=[])
    plt.subplots_adjust(wspace=0, hspace=0)

    for img, ax in zip(data, axs):
        ax.imshow(unnormalize_if_necessary(img))

    return figure
|
|
|
|
|
def unnormalize_if_necessary(x):
    """Rescale ``x`` with ``x * 0.5 + 0.5`` when it contains negative values.

    Only ``np.ndarray`` and ``tf.Tensor`` inputs are inspected; any other
    type is returned untouched. The rescale maps -1 to 0 and 1 to 1.
    NOTE(review): this presumably undoes a [-1, 1] normalisation; confirm
    against the preprocessing code.

    Args:
        x: Array, tensor, or any other object.

    Returns:
        The rescaled array/tensor if its minimum is below zero, otherwise
        ``x`` itself.
    """
    # Work out the smallest value for the supported container types only.
    if isinstance(x, np.ndarray):
        lowest = x.min()
    elif isinstance(x, tf.Tensor):
        lowest = x.numpy().min()
    else:
        return x

    if lowest < 0:
        return (x * 0.5) + 0.5
    return x
|
|
|
|
|
def display_true_pred(y_true, y_pred, n_cols=3):
    """Plot predictions above ground truth, ``n_cols`` images per row.

    The figure has two stacked sub-figures: the top one is titled
    "Prediction" and shows ``y_pred[0] .. y_pred[n_cols - 1]``; the bottom
    one is titled "Ground truth" and shows the same indices of ``y_true``.
    Both inputs are passed whole through ``unnormalize_if_necessary``, so the
    rescale decision is made per batch, not per image.

    Args:
        y_true: Indexable collection of ground-truth images.
        y_pred: Indexable collection of predicted images.
        n_cols: Number of images to show from each collection. Each
            collection must hold at least this many items, because images
            are fetched by index.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure(constrained_layout=True, figsize=(24, 12))

    y_true = unnormalize_if_necessary(y_true)
    y_pred = unnormalize_if_necessary(y_pred)

    # Row order matches the sub-figure order: predictions first.
    rows = (("Prediction", y_pred), ("Ground truth", y_true))

    subfigs = fig.subfigures(nrows=2, ncols=1)
    for subfig, (title, images) in zip(subfigs, rows):
        subfig.suptitle(title, fontsize=24)

        # squeeze=False keeps the axes in an array even when n_cols == 1;
        # otherwise a single Axes is returned and cannot be iterated.
        axs = subfig.subplots(nrows=1, ncols=n_cols, squeeze=False)
        for col, ax in enumerate(axs.flat):
            ax.imshow(images[col])

    return fig
|
|