|
from pathlib import Path |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from mpl_toolkits.axes_grid1 import ImageGrid |
|
from matplotlib.colors import Normalize, LogNorm |
|
import torch |
|
from torchvision.utils import make_grid |
|
from einops import rearrange |
|
from .data import blend_rgba |
|
|
|
import imageio |
|
|
|
from my.utils.plot import mpl_fig_to_buffer |
|
from my.utils.event import read_stats |
|
|
|
|
|
def vis(ref_img, pred_img, pred_depth, *, msg="", return_buffer=False,
        depth_range=(2, 10)):
    """Plot ground truth, rendering and expected depth side by side.

    The three images go on a 1x3 ``ImageGrid`` with a single colorbar on the
    right; that colorbar belongs to the depth panel. The depth panel uses the
    "Spectral" colormap with a ``LogNorm`` between ``depth_range[0]`` and
    ``depth_range[1]``.

    Args:
        ref_img: reference ("gt") image, anything ``Axes.imshow`` accepts.
        pred_img: rendered image, anything ``Axes.imshow`` accepts.
        pred_depth: depth map for the third panel (presumably 2-D so that the
            norm/colormap apply — confirm with callers).
        msg: text appended to the rendering panel's title.
        return_buffer: if True, convert the figure with ``mpl_fig_to_buffer``
            and return that instead of calling ``plt.show()``.
        depth_range: (vmin, vmax) for the depth panel's log normalization.

    Returns:
        The result of ``mpl_fig_to_buffer(fig)`` when ``return_buffer`` is
        True, otherwise None.
    """
    fig = plt.figure(figsize=(15, 6))
    grid = ImageGrid(
        fig, 111, nrows_ncols=(1, 3),
        cbar_location="right", cbar_mode="single",
    )

    grid[0].imshow(ref_img)
    grid[0].set_title("gt")

    grid[1].imshow(pred_img)
    grid[1].set_title(f"rendering {msg}")

    vmin, vmax = depth_range
    h = grid[2].imshow(pred_depth, norm=LogNorm(vmin=vmin, vmax=vmax), cmap="Spectral")
    grid[2].set_title("expected depth")
    plt.colorbar(h, cax=grid.cbar_axes[0])
    plt.tight_layout()

    if return_buffer:
        try:
            return mpl_fig_to_buffer(fig)
        finally:
            # pyplot keeps every figure alive until it is closed; without this,
            # repeated buffered calls accumulate figures. Closing an already
            # closed figure is a no-op.
            plt.close(fig)
    plt.show()
|
|
|
|
|
def _bad_vis(pred_img, pred_depth, *, return_buffer=False, depth_range=(0.5, 10)):
    """Emergency function for one-off use: rendering and depth side by side.

    Draws ``pred_img`` and ``pred_depth`` on a 1x2 subplot grid. The depth
    panel uses the "Spectral" colormap with a ``LogNorm`` between
    ``depth_range[0]`` and ``depth_range[1]``. No colorbar is drawn.

    Args:
        pred_img: rendered image, anything ``Axes.imshow`` accepts.
        pred_depth: depth map (presumably 2-D — confirm with callers).
        return_buffer: if True, convert the figure with ``mpl_fig_to_buffer``
            and return that instead of calling ``plt.show()``.
        depth_range: (vmin, vmax) for the depth panel's log normalization.

    Returns:
        The result of ``mpl_fig_to_buffer(fig)`` when ``return_buffer`` is
        True, otherwise None.
    """
    fig, grid = plt.subplots(1, 2, squeeze=True, figsize=(10, 6))

    grid[0].imshow(pred_img)
    grid[0].set_title("rendering")

    vmin, vmax = depth_range
    grid[1].imshow(pred_depth, norm=LogNorm(vmin=vmin, vmax=vmax), cmap="Spectral")
    grid[1].set_title("expected depth")

    plt.tight_layout()

    if return_buffer:
        try:
            return mpl_fig_to_buffer(fig)
        finally:
            # Release the pyplot-managed figure so repeated calls don't leak.
            plt.close(fig)
    plt.show()
|
|
|
|
|
# Module-level "Spectral" colormap; bad_vis calls it to turn normalized depth
# values into RGBA colors.
colormap = plt.get_cmap(name="Spectral")
|
|
|
|
|
def _colorize_depth(depth, max_depth):
    """Map a depth array to a (1, 3, h, w) float32 tensor via ``colormap``."""
    # log(1 + d) / log(1 + max_depth) sends depth 0 -> 0 and max_depth -> 1;
    # larger depths exceed 1 and get the colormap's "over" color.
    normalized = np.log1p(depth) / np.log1p(max_depth)
    rgb = blend_rgba(colormap(normalized))
    rgb = rearrange(rgb, "h w c -> 1 c h w", c=3)
    # float32 so later concatenation with the image does not depend on
    # mixed-dtype promotion (colormap output is float64).
    return torch.from_numpy(rgb).float()


def bad_vis(pred_img, pred_depth, final_H=512, max_depth=10.):
    """Build a uint8 RGB panel showing the rendering next to its colorized depth.

    Args:
        pred_img: 4-D image tensor (as required by bilinear ``interpolate``)
            whose values are mapped from [-1, 1] to [0, 1]; presumably shape
            (1, 3, H, W) — TODO confirm with callers.
        pred_depth: depth tensor that must be 2-D (h, w) after
            ``.cpu().numpy()`` for the colormap/rearrange step.
        final_H: both panels are resized to a square ``final_H x final_H``.
        max_depth: depth value mapped to the top of the colormap.

    Returns:
        ``np.ndarray`` of dtype uint8 laid out as (height, width, channels),
        produced by ``make_grid`` with two images per row.
    """
    size = (final_H, final_H)

    # detach(): .numpy() refuses tensors that require grad.
    depth = _colorize_depth(pred_depth.detach().cpu().numpy(), max_depth)
    depth = torch.nn.functional.interpolate(
        depth, size, mode="bilinear", antialias=True
    )

    pred_img = torch.nn.functional.interpolate(
        pred_img.detach(), size, mode="bilinear", antialias=True
    )
    pred_img = ((pred_img + 1) / 2).clamp(0, 1).float().cpu()

    stacked = torch.cat([pred_img, depth], dim=0)
    pane = make_grid(stacked, nrow=2)
    pane = rearrange(pane, "c h w -> h w c")
    pane = (pane * 255.).clamp(0, 255).to(torch.uint8)
    return pane.numpy()
|
|
|
|
|
def export_movie(seqs, fname, fps=30):
    """Write a sequence of frames to a video file using imageio.

    Args:
        seqs: iterable of frames, each passed to ``writer.append_data``
            (presumably HxWxC uint8 arrays — confirm with callers).
        fname: output path; ".mp4" is appended when it has no suffix.
        fps: frames per second passed to ``imageio.get_writer``.
    """
    fname = Path(fname)
    if not fname.suffix:
        fname = fname.with_suffix(".mp4")
    # The context manager closes (and finalizes) the writer even if a frame
    # fails to encode, instead of leaking the open file/encoder process.
    with imageio.get_writer(fname, fps=fps) as writer:
        for img in seqs:
            writer.append_data(img)
|
|
|
|
|
def stitch_vis(save_fn, img_fnames, fps=10):
    """Assemble image files into a video, one file per frame, in the given order.

    Frames are decoded lazily while ``export_movie`` iterates, instead of
    reading every image into memory up front. As a consequence, an unreadable
    file midway leaves a partially written video rather than failing before
    anything is written.

    Args:
        save_fn: output path, forwarded to ``export_movie``.
        img_fnames: iterable of image paths readable by ``imageio.imread``.
        fps: frames per second of the output video.
    """
    frames = (imageio.imread(fn) for fn in img_fnames)
    export_movie(frames, save_fn, fps)
|
|