# This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter
# SPDX-License-Identifier: MIT
#
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
# Original file is released under the MIT License, with the full license text available at
# https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file.
from typing import Union, List
import tempfile

import numpy as np
import PIL.Image
import matplotlib.cm as cm
import mediapy
import torch
from decord import VideoReader, cpu


def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dataset="open"):
    """Read a video into a (num_frames, height, width, 3) uint8 array, with optional
    downscaling (max_res) and temporal subsampling (target_fps, process_length)."""
    vid = VideoReader(video_path, ctx=cpu(0))
    original_height, original_width = vid.get_batch([0]).shape[1:3]

    # Downscale so the longer side does not exceed max_res (when max_res > 0).
    height = original_height
    width = original_width
    if max_res > 0 and max(height, width) > max_res:
        scale = max_res / max(original_height, original_width)
        height = round(original_height * scale)
        width = round(original_width * scale)
        # Re-open the reader so decord decodes directly at the reduced resolution.
        vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)

    # Subsample frames to approximate target_fps (-1 keeps the native frame rate),
    # and cap the number of frames at process_length (-1 keeps them all).
    fps = vid.get_avg_fps() if target_fps == -1 else target_fps
    stride = round(vid.get_avg_fps() / fps)
    stride = max(stride, 1)
    frames_idx = list(range(0, len(vid), stride))
    if process_length != -1 and process_length < len(frames_idx):
        frames_idx = frames_idx[:process_length]
    frames = vid.get_batch(frames_idx).asnumpy()
    return frames, fps


def save_video(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
    output_video_path: str = None,
    fps: int = 10,
    crf: int = 18,
) -> str:
    """Encode a list of frames to an mp4 file and return the output path."""
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    # Normalize frames to uint8 numpy arrays for mediapy.
    if isinstance(video_frames[0], np.ndarray):
        video_frames = [frame.astype(np.uint8) for frame in video_frames]
    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
    return output_video_path


class ColorMapper:
    """Map scalar (depth) values to RGB colors via a matplotlib colormap."""

    def __init__(self, colormap: str = "inferno"):
        self.colormap = torch.tensor(cm.get_cmap(colormap).colors)

    def apply(self, image: torch.Tensor, v_min=None, v_max=None):
        # Normalize to [0, 1] using the tensor's own range unless explicit bounds
        # are given, then look up each value in the 256-entry colormap.
        if v_min is None:
            v_min = image.min()
        if v_max is None:
            v_max = image.max()
        image = (image - v_min) / (v_max - v_min)
        image = (image * 255).long()
        image = self.colormap[image] * 255
        return image


def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
    """Colorize a depth sequence; returns RGB values scaled to [0, 255]."""
    visualizer = ColorMapper()
    if v_min is None:
        v_min = depths.min()
    if v_max is None:
        v_max = depths.max()
    res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
    return res
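

# Minimal usage sketch (not from the original module): the paths "input.mp4" and
# "depth_vis.mp4" are hypothetical, and the random array merely stands in for
# per-frame depth predictions of shape (num_frames, height, width).
if __name__ == "__main__":
    frames, fps = read_video_frames("input.mp4", process_length=32, target_fps=15, max_res=1024)

    # Placeholder "depth" values; a real pipeline would use model output here.
    depths = np.random.rand(*frames.shape[:3]).astype(np.float32)

    vis = vis_sequence_depth(depths)  # (num_frames, height, width, 3), floats in [0, 255]
    save_video(list(vis), output_video_path="depth_vis.mp4", fps=int(fps))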