# Page header carried over from the hosting site (not part of the module):
# memef4rmer's picture
# Duplicate from jyseo/3DFuse
# efe5745
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from matplotlib.colors import Normalize, LogNorm
import torch
from torchvision.utils import make_grid
from einops import rearrange
from .data import blend_rgba
import imageio
from my.utils.plot import mpl_fig_to_buffer
from my.utils.event import read_stats
def vis(ref_img, pred_img, pred_depth, *, msg="", return_buffer=False):
    """Draw reference image, rendering and depth map in one row of three panels.

    The depth panel uses the "Spectral" colormap with a logarithmic
    normalisation between 2 and 10, and shares a single colorbar placed on the
    right of the grid.

    Args:
        ref_img: image shown in the first panel, titled "gt".
        pred_img: image shown in the second panel, titled "rendering <msg>".
        pred_depth: depth map shown in the third panel, titled "expected depth".
        msg: extra text appended to the rendering panel title.
        return_buffer: when true, hand the figure to ``mpl_fig_to_buffer`` and
            return its result instead of opening a window.

    Returns:
        The result of ``mpl_fig_to_buffer(fig)`` when ``return_buffer`` is
        true; otherwise ``None`` after calling ``plt.show()``.
    """
    figure = plt.figure(figsize=(15, 6))
    panels = ImageGrid(
        figure, 111, nrows_ncols=(1, 3),
        cbar_location="right", cbar_mode="single",
    )
    gt_ax, render_ax, depth_ax = panels[0], panels[1], panels[2]

    gt_ax.imshow(ref_img)
    gt_ax.set_title("gt")

    render_ax.imshow(pred_img)
    render_ax.set_title(f"rendering {msg}")

    depth_artist = depth_ax.imshow(
        pred_depth, norm=LogNorm(vmin=2, vmax=10), cmap="Spectral"
    )
    depth_ax.set_title("expected depth")

    # The single colorbar of the grid is driven by the depth panel.
    plt.colorbar(depth_artist, cax=panels.cbar_axes[0])
    plt.tight_layout()

    if not return_buffer:
        plt.show()
        return None
    return mpl_fig_to_buffer(figure)
def _bad_vis(pred_img, pred_depth, *, return_buffer=False):
    """Emergency one-off plot: rendering and depth map side by side.

    Uses plain ``plt.subplots`` axes and draws no colorbar. The depth panel
    uses the "Spectral" colormap with a logarithmic normalisation between
    0.5 and 10.

    Args:
        pred_img: image shown in the left panel, titled "rendering".
        pred_depth: depth map shown in the right panel, titled "expected depth".
        return_buffer: when true, hand the figure to ``mpl_fig_to_buffer`` and
            return its result instead of opening a window.

    Returns:
        The result of ``mpl_fig_to_buffer(fig)`` when ``return_buffer`` is
        true; otherwise ``None`` after calling ``plt.show()``.
    """
    figure, axes = plt.subplots(1, 2, squeeze=True, figsize=(10, 6))
    render_ax, depth_ax = axes[0], axes[1]

    render_ax.imshow(pred_img)
    render_ax.set_title("rendering")

    depth_ax.imshow(
        pred_depth, norm=LogNorm(vmin=0.5, vmax=10), cmap="Spectral"
    )
    depth_ax.set_title("expected depth")

    plt.tight_layout()

    if not return_buffer:
        plt.show()
        return None
    return mpl_fig_to_buffer(figure)
# Module-level colormap used to colourise depth maps in bad_vis.
colormap = plt.get_cmap("Spectral")


def bad_vis(pred_img, pred_depth, final_H=512):
    """Build a single uint8 image pane of the rendering next to its depth map.

    The depth map is log-compressed, colourised with the module-level
    ``colormap``, flattened to three channels by ``blend_rgba`` and resized to
    ``final_H x final_H``. The rendering is resized to the same size, mapped
    with ``(x + 1) / 2`` and clamped to [0, 1]. Both are then tiled with
    ``make_grid`` (two images per row).

    Args:
        pred_img: tensor passed to ``torch.nn.functional.interpolate`` with a
            2-D target size and concatenated with a ``1 x 3 x H x H`` depth
            tensor, so it is expected to be ``N x 3 x h x w``.
            NOTE(review): the ``(x + 1) / 2`` mapping suggests values in
            [-1, 1] -- confirm against the caller.
        pred_depth: tensor exposing ``.cpu().numpy()``.
            NOTE(review): presumably a 2-D ``h x w`` depth map, since the
            colourised result is rearranged as ``h w c`` -- confirm.
        final_H: side length, in pixels, of each resized square image.

    Returns:
        ``numpy`` array of dtype uint8 laid out as ``h w c`` holding the grid
        produced by ``make_grid``.
    """
    target_size = (final_H, final_H)

    # Depth: log(1 + d) scaled by log(11), so a depth of 10 lands at ~1.0
    # before the colormap lookup.
    depth_np = pred_depth.cpu().numpy()
    log_depth = np.log(1. + depth_np + 1e-12)
    log_depth = log_depth / np.log(1 + 10.)
    depth_rgb = blend_rgba(colormap(log_depth))
    depth_t = torch.from_numpy(
        rearrange(depth_rgb, "h w c -> 1 c h w", c=3)
    )
    depth_t = torch.nn.functional.interpolate(
        depth_t, target_size, mode="bilinear", antialias=True
    )

    # Rendering: resize first, then map to [0, 1] and move to CPU.
    rgb = torch.nn.functional.interpolate(
        pred_img, target_size, mode="bilinear", antialias=True
    )
    rgb = ((rgb + 1) / 2).clamp(0, 1).cpu()

    # Tile rendering(s) and depth, then convert CHW float -> HWC uint8.
    tiled = make_grid(torch.cat([rgb, depth_t], dim=0), nrow=2)
    tiled = rearrange(tiled, "c h w -> h w c")
    return (tiled * 255.).clamp(0, 255).to(torch.uint8).numpy()
def export_movie(seqs, fname, fps=30):
    """Write a sequence of frames to a movie file with imageio.

    Args:
        seqs: iterable of frames; each one is passed to the writer's
            ``append_data``.
        fname: output path (string or ``Path``). When it has no file suffix,
            ``.mp4`` is appended.
        fps: frame rate forwarded to ``imageio.get_writer``.
    """
    fname = Path(fname)
    if fname.suffix == "":
        fname = fname.with_suffix(".mp4")

    # Use the writer as a context manager so it is closed even when a frame
    # fails to encode or the frame iterable raises part-way through.
    with imageio.get_writer(fname, fps=fps) as writer:
        for img in seqs:
            writer.append_data(img)
def stitch_vis(save_fn, img_fnames, fps=10):
    """Read image files in order and stitch them into a movie.

    Args:
        save_fn: output movie path, forwarded to ``export_movie``.
        img_fnames: iterable of image file names, read with ``imageio.imread``.
        fps: frame rate forwarded to ``export_movie``.
    """
    # All frames are loaded before the movie is written.
    frames = []
    for image_path in img_fnames:
        frames.append(imageio.imread(image_path))
    export_movie(frames, save_fn, fps)