Spaces:
Running
on
L40S
Running
on
L40S
# Copyright 2022 Google LLC | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# https://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Utility functions for frame interpolation on a set of video frames.""" | |
import os | |
import shutil | |
from typing import Generator, Iterable, List, Optional | |
from . import interpolator as interpolator_lib | |
import numpy as np | |
import tensorflow as tf | |
from tqdm import tqdm | |
_UINT8_MAX_F = float(np.iinfo(np.uint8).max) | |
_CONFIG_FFMPEG_NAME_OR_PATH = 'ffmpeg' | |
def read_image(filename: str) -> np.ndarray: | |
"""Reads an sRgb 8-bit image. | |
Args: | |
filename: The input filename to read. | |
Returns: | |
A float32 3-channel (RGB) ndarray with colors in the [0..1] range. | |
""" | |
image_data = tf.io.read_file(filename) | |
image = tf.io.decode_image(image_data, channels=3) | |
image_numpy = tf.cast(image, dtype=tf.float32).numpy() | |
return image_numpy / _UINT8_MAX_F | |
def write_image(filename: str, image: np.ndarray) -> None: | |
"""Writes a float32 3-channel RGB ndarray image, with colors in range [0..1]. | |
Args: | |
filename: The output filename to save. | |
image: A float32 3-channel (RGB) ndarray with colors in the [0..1] range. | |
""" | |
image_in_uint8_range = np.clip(image * _UINT8_MAX_F, 0.0, _UINT8_MAX_F) | |
image_in_uint8 = (image_in_uint8_range + 0.5).astype(np.uint8) | |
extension = os.path.splitext(filename)[1] | |
if extension == '.jpg': | |
image_data = tf.io.encode_jpeg(image_in_uint8) | |
else: | |
image_data = tf.io.encode_png(image_in_uint8) | |
tf.io.write_file(filename, image_data) | |
def _recursive_generator( | |
frame1: np.ndarray, frame2: np.ndarray, num_recursions: int, | |
interpolator: interpolator_lib.Interpolator, | |
bar: Optional[tqdm] = None | |
) -> Generator[np.ndarray, None, None]: | |
"""Splits halfway to repeatedly generate more frames. | |
Args: | |
frame1: Input image 1. | |
frame2: Input image 2. | |
num_recursions: How many times to interpolate the consecutive image pairs. | |
interpolator: The frame interpolator instance. | |
Yields: | |
The interpolated frames, including the first frame (frame1), but excluding | |
the final frame2. | |
""" | |
if num_recursions == 0: | |
yield frame1 | |
else: | |
# Adds the batch dimension to all inputs before calling the interpolator, | |
# and remove it afterwards. | |
time = np.full(shape=(1,), fill_value=0.5, dtype=np.float32) | |
mid_frame = interpolator(frame1[np.newaxis, ...], frame2[np.newaxis, ...], | |
time)[0] | |
bar.update(1) if bar is not None else bar | |
yield from _recursive_generator(frame1, mid_frame, num_recursions - 1, | |
interpolator, bar) | |
yield from _recursive_generator(mid_frame, frame2, num_recursions - 1, | |
interpolator, bar) | |
def interpolate_recursively_from_files( | |
frames: List[str], times_to_interpolate: int, | |
interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]: | |
"""Generates interpolated frames by repeatedly interpolating the midpoint. | |
Loads the files on demand and uses the yield paradigm to return the frames | |
to allow streamed processing of longer videos. | |
Recursive interpolation is useful if the interpolator is trained to predict | |
frames at midpoint only and is thus expected to perform poorly elsewhere. | |
Args: | |
frames: List of input frames. Expected shape (H, W, 3). The colors should be | |
in the range[0, 1] and in gamma space. | |
times_to_interpolate: Number of times to do recursive midpoint | |
interpolation. | |
interpolator: The frame interpolation model to use. | |
Yields: | |
The interpolated frames (including the inputs). | |
""" | |
n = len(frames) | |
num_frames = (n - 1) * (2**(times_to_interpolate) - 1) | |
bar = tqdm(total=num_frames, ncols=100, colour='green') | |
for i in range(1, n): | |
yield from _recursive_generator( | |
read_image(frames[i - 1]), read_image(frames[i]), times_to_interpolate, | |
interpolator, bar) | |
# Separately yield the final frame. | |
yield read_image(frames[-1]) | |
def interpolate_recursively_from_memory( | |
frames: List[np.ndarray], times_to_interpolate: int, | |
interpolator: interpolator_lib.Interpolator) -> Iterable[np.ndarray]: | |
"""Generates interpolated frames by repeatedly interpolating the midpoint. | |
This is functionally equivalent to interpolate_recursively_from_files(), but | |
expects the inputs frames in memory, instead of loading them on demand. | |
Recursive interpolation is useful if the interpolator is trained to predict | |
frames at midpoint only and is thus expected to perform poorly elsewhere. | |
Args: | |
frames: List of input frames. Expected shape (H, W, 3). The colors should be | |
in the range[0, 1] and in gamma space. | |
times_to_interpolate: Number of times to do recursive midpoint | |
interpolation. | |
interpolator: The frame interpolation model to use. | |
Yields: | |
The interpolated frames (including the inputs). | |
""" | |
n = len(frames) | |
num_frames = (n - 1) * (2**(times_to_interpolate) - 1) | |
bar = tqdm(total=num_frames, ncols=100, colour='green') | |
for i in range(1, n): | |
yield from _recursive_generator(frames[i - 1], frames[i], | |
times_to_interpolate, interpolator, bar) | |
# Separately yield the final frame. | |
yield frames[-1] | |
def get_ffmpeg_path() -> str: | |
path = shutil.which(_CONFIG_FFMPEG_NAME_OR_PATH) | |
if not path: | |
raise RuntimeError( | |
f"Program '{_CONFIG_FFMPEG_NAME_OR_PATH}' is not found;" | |
" perhaps install ffmpeg using 'apt-get install ffmpeg'.") | |
return path | |