conv-lstm / app.py
nouamanetazi's picture
nouamanetazi HF staff
add credits
fa8b0cf
raw
history blame contribute delete
No virus
3.99 kB
import os
import yaml
import gradio as gr
import numpy as np
import imageio, cv2
from moviepy.editor import *
from skimage.transform import resize
from skimage import img_as_ubyte
from skimage.color import rgb2gray
from huggingface_hub.keras_mixin import from_pretrained_keras
# Load the pretrained ConvLSTM next-frame model from the Hugging Face Hub.
model = from_pretrained_keras("keras-io/conv-lstm")

# Examples: one row per clip in asset/source, in the positional order of the
# UI inputs (video path, prediction start fraction, predict-one flag).
# Sorted because os.listdir returns entries in arbitrary order, which would
# make the examples gallery order vary between deployments; hidden files
# (e.g. .DS_Store) are skipped so they are not offered as example videos.
example_source = sorted(
    name for name in os.listdir('asset/source') if not name.startswith('.')
)
samples = [[f'asset/source/{video}', 0.5, True] for video in example_source]
def _read_frames(source):
    """Decode ``source`` and return ``(frames, fps)``.

    Each frame is resized to 64x64, converted with ``rgb2gray`` and given a
    trailing channel axis. For RGB input, this yields an array of shape
    ``(T, 64, 64, 1)``.
    """
    with imageio.get_reader(source) as reader:
        fps = reader.get_meta_data()['fps']
        raw_frames = []
        try:
            for im in reader:
                raw_frames.append(im)
        except RuntimeError:
            # Deliberate best effort: keep every frame decoded before the
            # reader failed instead of rejecting the whole video.
            pass
    frames = [rgb2gray(resize(frame, (64, 64)))[..., np.newaxis] for frame in raw_frames]
    return np.array(frames), fps


def _predict_frames(example, start_pred_id, predict_one):
    """Predict frames ``example[start_pred_id:]`` with the module-level ``model``.

    Assumes, as in the Keras conv-LSTM example this model comes from, that
    output step t is the prediction for frame t + 1. Feeding the first k
    frames and taking the last output therefore predicts frame k.
    """
    n_future = example.shape[0] - start_pred_id
    predictions = np.zeros(shape=(n_future, *example.shape[1:]))
    for i in range(n_future):
        if predict_one:
            # Teacher forcing: condition only on the true frames that precede
            # the target frame example[start_pred_id + i].
            context = example[:start_pred_id + i]
        else:
            # Autoregressive: true prefix followed by our own earlier predictions.
            context = np.concatenate((example[:start_pred_id], predictions[:i]), axis=0)
        output = model.predict(np.expand_dims(context, axis=0))
        predictions[i] = np.squeeze(output, axis=0)[-1]
    return predictions


def _save_video(frames, path, fps):
    """Write single-channel frames ``(T, H, W, 1)`` to ``path`` as an RGB video.

    Values are expected in [0, 1]. They are clipped to that range before
    being scaled to uint8, so out-of-range model outputs cannot wrap around.
    """
    # np.repeat on the channel axis keeps the time axis even when T == 1.
    rgb = np.repeat(np.clip(frames, 0.0, 1.0), 3, axis=-1)
    imageio.mimsave(path, list((rgb * 255).astype(np.uint8)), fps=fps)


def inference(source,
              split_pred=0.4,  # fraction of frames used as context; the rest is predicted
              predict_one=False,  # True: predict each frame from the true preceding frames
              output_name='output.mp4',
              output_path='asset/output',
              cpu=False,
              ):
    """Predict the tail of a video and save predicted and ground-truth clips.

    Args:
        source: Path of the input video, as accepted by ``imageio.get_reader``.
        split_pred: Prediction starts at frame ``int(split_pred * T)``. The
            start is clamped to ``[1, T - 1]`` so there is at least one
            context frame and at least one frame to predict.
        predict_one: If True, every frame is predicted from the true frames
            before it. If False, frames are predicted autoregressively from
            the true prefix plus earlier predictions.
        output_name: File name of the predicted video inside ``output_path``.
        output_path: Directory for both videos; created if missing. The
            ground truth is always written as ``original.mp4``.
        cpu: Unused; kept for interface compatibility.

    Returns:
        ``(predicted_video_path, ground_truth_video_path)``.

    Raises:
        ValueError: If fewer than two frames could be read from ``source``.
    """
    example, fps = _read_frames(source)
    n_frames = example.shape[0]
    if n_frames < 2:
        raise ValueError(f"need at least 2 readable frames, got {n_frames}")

    start_pred_id = min(max(int(split_pred * n_frames), 1), n_frames - 1)
    original_frames = example[start_pred_id:]
    predictions = _predict_frames(example, start_pred_id, predict_one)

    os.makedirs(output_path, exist_ok=True)
    original_file = f'{output_path}/original.mp4'
    prediction_file = f'{output_path}/{output_name}'
    _save_video(original_frames, original_file, fps)
    _save_video(predictions, prediction_file, fps)
    return prediction_file, original_file
# Credits rendered below the demo.
article = "<div style='text-align: center;'><a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a><br><a href='https://keras.io/examples/vision/conv_lstm/' target='_blank'>Keras example by Amogh Joshi</a></div>"

# Web UI. The inputs map positionally onto inference(source, split_pred,
# predict_one); the outputs are the two paths inference returns.
# `iface` now holds the Interface itself. Previously it was bound to
# whatever .launch() returned.
iface = gr.Interface(
    inference,  # main function
    inputs=[
        gr.inputs.Video(label='Video', type='mp4'),
        gr.inputs.Slider(minimum=.1, maximum=.9, default=.5, step=.001, label="prediction start"),
        gr.inputs.Checkbox(label="predict one frame only", default=True),
    ],
    outputs=[
        gr.outputs.Video(label='result'),  # generated video
        gr.outputs.Video(label='ground truth'),  # same part of original video
    ],
    title='Next-Frame Video Prediction with Convolutional LSTMs',
    # description = "This app is an unofficial demo web app of the Next-Frame Video Prediction with Convolutional LSTMs by Keras.",
    article=article,
    examples=samples,
)
iface.launch(enable_queue=True, cache_examples=True)