# lipnet/helpers.py
# Uploaded by milselarch using huggingface_hub (commit 3a3c68a, 2.4 kB).
import os
import cv2
import tensorflow as tf
import numpy as np
import yaml
from typing import List
# Dataset locations come from config.yml. The relative path is resolved
# against the current working directory, not this module's directory.
# An explicit encoding avoids depending on the platform default
# (e.g. cp1252 on Windows) when paths contain non-ASCII characters.
with open('config.yml', 'r', encoding='utf-8') as config_file_obj:
    # safe_load only builds plain YAML types, never arbitrary Python objects.
    yaml_config = yaml.safe_load(config_file_obj)

dataset_config = yaml_config['datasets']
VIDEO_DIR = dataset_config['video_dir']
ALIGNMENTS_DIR = dataset_config['alignments_dir']
# Characters that can appear in a transcription target, one entry per
# character. Note that the digit '0' is not part of this vocabulary.
vocab = list("abcdefghijklmnopqrstuvwxyz'?!123456789 ")

# Forward mapping: character -> integer id. Characters outside `vocab`
# map to the out-of-vocabulary index, whose token is the empty string.
char_to_num = tf.keras.layers.StringLookup(vocabulary=vocab, oov_token="")

# Inverse mapping: integer id -> character. It is built from the forward
# layer's full vocabulary (OOV token included), so ids round-trip.
num_to_char = tf.keras.layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(),
    oov_token="",
    invert=True,
)
def load_video(path: str) -> tf.Tensor:
    """Decode a video, crop a fixed region from every frame and standardise it.

    Each decoded frame is converted to one grayscale channel and cropped to
    rows 190:236 and columns 80:220. The stacked clip is then standardised
    with the mean and standard deviation computed over the whole clip.

    Args:
        path: Path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        A float32 tensor of shape ``(num_frames, height, width, 1)``. For
        frames at least 236 rows by 220 columns this is
        ``(num_frames, 46, 140, 1)``.

    Raises:
        ValueError: If no frame could be decoded from ``path``.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        # The container's frame count can be inaccurate, so treat it as an
        # upper bound and stop as soon as the decoder fails to return a frame.
        for _ in range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))):
            ret, frame = cap.read()
            if not ret:
                break
            # NOTE(review): OpenCV decodes frames in BGR order, while
            # rgb_to_grayscale applies RGB channel weights. This is left
            # unchanged so the grayscale conversion matches existing code.
            frame = tf.image.rgb_to_grayscale(frame)
            frames.append(frame[190:236, 80:220, :])
    finally:
        # Always free the decoder, even if a frame fails to convert.
        cap.release()

    if not frames:
        raise ValueError(f'could not decode any frames from {path!r}')

    # Do all arithmetic in float32. The previous version subtracted a uint8
    # mean from uint8 pixels, which wraps around for every pixel darker than
    # the mean. NOTE(review): this changes the output values, so models
    # trained on the old preprocessing may need to be re-validated.
    video = tf.cast(tf.stack(frames), tf.float32)
    mean = tf.math.reduce_mean(video)
    std = tf.math.reduce_std(video)
    return (video - mean) / std
def load_alignments(path: str) -> tf.Tensor:
    """Encode the words listed in an alignment file as character ids.

    Each line is split on whitespace, and its third field is taken as the
    word (presumably lines look like ``<start> <end> <word>``; confirm
    against the dataset). ``'sil'`` entries are dropped. The remaining words
    are joined with single spaces, split into characters and mapped through
    ``char_to_num``.

    Args:
        path: Path to a text alignment file.

    Returns:
        A 1-D tensor of ``char_to_num`` ids for ``"word1 word2 ..."``.

    Raises:
        ValueError: If a non-blank line has fewer than three fields.
    """
    tokens = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines instead of failing with IndexError.
                continue
            if len(fields) < 3:
                raise ValueError(
                    f'malformed alignment line in {path!r}: {line!r}'
                )
            word = fields[2]
            if word != 'sil':
                # Every word is preceded by a space. The leading space is
                # removed by the [1:] slice on the result below.
                tokens.extend((' ', word))

    # unicode_split gives one row of characters per token; reshape flattens
    # them into a single character sequence.
    chars = tf.reshape(
        tf.strings.unicode_split(tokens, input_encoding='UTF-8'), (-1)
    )
    return char_to_num(chars)[1:]
def load_data(tf_path):
    """Load the ``(frames, alignments)`` pair for a single sample.

    Only the sample's parent directory name and file stem are used from the
    given path. The files are then looked up as
    ``<VIDEO_DIR>/<parent>/<stem>.mpg`` and
    ``<ALIGNMENTS_DIR>/<parent>/<stem>.align``.

    Args:
        tf_path: The sample path, either as an eager scalar string tensor
            (as passed by ``tf.py_function``) or as a plain ``str`` /
            ``bytes`` path for direct calls.

    Returns:
        A tuple ``(load_video(video_path), load_alignments(alignment_path))``.

    Raises:
        Any exception raised by ``load_video`` (after printing the offending
        path) or by ``load_alignments``.
    """
    # Unwrap an eager tensor to its bytes value, then decode to str.
    # Plain str or bytes arguments pass straight through to the same logic.
    raw_path = tf_path.numpy() if hasattr(tf_path, 'numpy') else tf_path
    path = raw_path.decode('utf-8') if isinstance(raw_path, bytes) else raw_path

    # Use os.path helpers rather than splitting on '/', so that Windows-style
    # separators are handled too.
    dir_name = os.path.basename(os.path.dirname(path))
    base_name = os.path.splitext(os.path.basename(path))[0]
    new_base_path = os.path.join(dir_name, base_name)

    video_path = os.path.join(VIDEO_DIR, f'{new_base_path}.mpg')
    alignment_path = os.path.join(ALIGNMENTS_DIR, f'{new_base_path}.align')

    try:
        frames = load_video(video_path)
    except Exception:
        # Name the offending file; inside tf.py_function the original
        # traceback is otherwise hard to trace back to a specific sample.
        print('BAD_VIDEO', video_path)
        raise
    alignments = load_alignments(alignment_path)
    return frames, alignments
def mappable_function(path: tf.Tensor) -> List[tf.Tensor]:
    """Run ``load_data`` eagerly via ``tf.py_function`` for one path.

    This is presumably used as the function for ``tf.data.Dataset.map``
    over a dataset of file paths.

    Args:
        path: Scalar string tensor holding the sample's file path.

    Returns:
        ``[frames, alignments]``, where ``frames`` is float32 and
        ``alignments`` is int64. NOTE(review): ``tf.py_function`` outputs
        carry no static shape, so callers may need ``set_shape`` or padding.
    """
    return tf.py_function(load_data, [path], (tf.float32, tf.int64))