import os
import yaml
import gradio as gr
import numpy as np
import imageio, cv2
from moviepy.editor import *
from skimage.transform import resize
from skimage import img_as_ubyte
from skimage.color import rgb2gray
from tensorflow import keras
# load model
model = keras.models.load_model('saved_model')
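# Note (assumption based on the preprocessing in inference() below): the SavedModel appears to
# take grayscale frame sequences of shape (batch, time, 64, 64, 1) and to return a sequence of
# frames with the same spatial shape, the last of which is used as the next-frame prediction.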
# Examples
samples = []
example_driving = os.listdir('asset/driving')
for video in example_driving:
    samples.append([f'asset/driving/{video}', 0.5, False])
def inference(driving,
              split_pred=0.4,      # fraction of frames used as context; the remaining (1 - split_pred) are predicted
              predict_one=False,   # True: predict each frame from the ground-truth history; False: feed predictions back in autoregressively
              output_name='output.mp4',
              output_path='asset/output',
              cpu=False,           # unused
              ):
    # Read the driving video and convert every frame to 64x64 single-channel grayscale.
    reader = imageio.get_reader(driving)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()
    driving_video = [rgb2gray(resize(frame, (64, 64)))[..., np.newaxis] for frame in driving_video]
    example = np.array(driving_video)  # shape: (num_frames, 64, 64, 1)
    print(example.shape)
    # Split the clip: frames before start_pred_id are the conditioning context,
    # the remaining frames are withheld and predicted.
    start_pred_id = int(split_pred * example.shape[0])  # prediction starts from frame start_pred_id
    frames = example[:start_pred_id, ...]
    original_frames = example[start_pred_id:, ...]
    new_predictions = np.zeros(shape=(example.shape[0] - start_pred_id, *frames[0].shape))

    # Predict the withheld frames one at a time.
    for i in range(example.shape[0] - start_pred_id):
        # Build the input sequence, then extract the model's prediction and post-process it.
        if predict_one:
            # condition on the ground-truth frames up to the current step
            frames = example[:start_pred_id + i + 1, ...]
        else:
            # condition on the context plus the model's own previous predictions
            frames = np.concatenate((example[:start_pred_id + 1, ...], new_predictions[:i, ...]), axis=0)
        new_prediction = model.predict(np.expand_dims(frames, axis=0))
        new_prediction = np.squeeze(new_prediction, axis=0)
        predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)

        # Extend the set of prediction frames.
        new_predictions[i] = predicted_frame
    # Create and save an MP4 for each of the ground-truth / prediction frame sets.
    def postprocess(frame_set, save_file):
        # Convert the selected frames to 8-bit RGB and write them out as a video.
        current_frames = np.squeeze(frame_set)
        current_frames = current_frames[..., np.newaxis] * np.ones(3)
        current_frames = (current_frames * 255).astype(np.uint8)
        current_frames = list(current_frames)
        print(f'{output_path}/{save_file}')
        imageio.mimsave(f'{output_path}/{save_file}', current_frames, fps=fps)

    # save videos
    os.makedirs(output_path, exist_ok=True)
    postprocess(original_frames, "original.mp4")
    postprocess(new_predictions, output_name)
    return f'{output_path}/{output_name}', f'{output_path}/original.mp4'
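
# Hypothetical usage sketch (not part of the original Space): inference() can also be called
# directly, bypassing the Gradio UI. The clip path below is an assumed example file, not one
# known to ship with the repo.
#
#   pred_path, gt_path = inference('asset/driving/clip.mp4', split_pred=0.5, predict_one=False)
#   print('prediction:', pred_path, '| ground truth:', gt_path)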
iface = gr.Interface(
    inference,  # main function
    inputs=[
        gr.inputs.Video(label='Video', type='mp4'),
        gr.inputs.Slider(minimum=.1, maximum=.9, default=.5, step=.001, label="prediction start"),
        gr.inputs.Checkbox(label="predict one frame only", default=False),
    ],
    outputs=[
        gr.outputs.Video(label='result'),        # generated video
        gr.outputs.Video(label='ground truth'),  # same part of the original video
    ],
    title='Next-Frame Video Prediction with Convolutional LSTMs',
    description="This app is an unofficial web demo of the Keras example 'Next-Frame Video Prediction with Convolutional LSTMs'.",
    examples=samples,
).launch(enable_queue=True, cache_examples=True)
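# Running `python app.py` serves the Gradio demo on a local server; on Hugging Face Spaces,
# app.py is executed the same way, so the Interface above is launched automatically.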