lipnet / .ipynb_checkpoints /train-checkpoint.py
milselarch's picture
Upload folder using huggingface_hub
3a3c68a
raw
history blame
1.12 kB
import os
import cv2
import tensorflow as tf
import numpy as np
import imageio
import yaml
from matplotlib import pyplot as plt
from helpers import *
from typing import List
from Loader import GridLoader
# Load dataset locations from the project's YAML config.
# 'config.yml' is resolved relative to the current working directory.
# safe_load is used so the config cannot construct arbitrary Python objects.
with open('config.yml', 'r', encoding='utf-8') as config_file_obj:
    yaml_config = yaml.safe_load(config_file_obj)

# Read the dataset section after the file is closed; a missing key raises KeyError.
dataset_config = yaml_config['datasets']
VIDEO_DIR = dataset_config['video_dir']
ALIGNMENTS_DIR = dataset_config['alignments_dir']
# Build the tf.data input pipeline. GridLoader and mappable_function come from
# the project's Loader / helpers modules, which are not shown here.
loader = GridLoader()
# NOTE(review): assumes load_videos() returns a sequence of video file paths
# (each element is decoded as UTF-8 text below) -- confirm in Loader.
data = tf.data.Dataset.from_tensor_slices(loader.load_videos())

# Record every file path as a str, in load order, before the shuffle below.
filenames = [file_path.decode("utf-8") for file_path in data.as_numpy_iterator()]

# reshuffle_each_iteration=False fixes the shuffled order across iterations,
# so take/skip splits drawn from this dataset stay disjoint.
data = data.shuffle(500, reshuffle_each_iteration=False)
data = data.map(mappable_function)
# Batch pairs of examples. Within a batch, the first component is padded to 75
# along axis 0 (its other three axes to the batch maximum) and the second to 40.
# NOTE(review): presumably (frames, height, width, channels) video tensors and
# encoded alignment labels -- confirm against mappable_function.
data = data.padded_batch(2, padded_shapes=(
    [75, None, None, None], [40]
))
# Overlap preprocessing with consumption.
data = data.prefetch(tf.data.AUTOTUNE)
# Train/test split. At this point the elements of `data` are batches (it went
# through padded_batch(2)), so take/skip count batches: 450 batches = 900 examples.
train = data.take(450)
test = data.skip(450)

# Eagerly pull one batch as a quick smoke test of the pipeline; each batch is a
# (frames, alignments) pair.
frames, alignments = next(data.as_numpy_iterator())