GANime / ganime /data /experimental.py
Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
7.55 kB
from abc import ABC, abstractclassmethod, abstractmethod
import glob
import math
import os
from typing import Dict
from typing_extensions import dataclass_transform
import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))): # if value ist tensor
value = value.numpy() # get value of tensor
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Returns a floast_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def serialize_array(array):
array = tf.io.serialize_tensor(array)
return array
class Dataset(ABC):
def __init__(self, dataset_path: str):
self.dataset_path = dataset_path
@classmethod
def _parse_single_element(cls, element) -> tf.train.Example:
features = tf.train.Features(feature=cls._get_features(element))
return tf.train.Example(features=features)
@abstractclassmethod
def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
pass
@abstractclassmethod
def _parse_tfr_element(cls, element):
pass
@classmethod
def write_to_tfr(cls, data: np.ndarray, out_dir: str, filename: str):
if not os.path.exists(out_dir):
os.makedirs(out_dir)
# Write all elements to a single tfrecord file
single_file_name = cls.__write_to_single_tfr(data, out_dir, filename)
# The optimal size for a single tfrecord file is around 100 MB. Get the number of files that need to be created
number_splits = cls.__get_number_splits(single_file_name)
if number_splits > 1:
os.remove(single_file_name)
cls.__write_to_multiple_tfr(data, out_dir, filename, number_splits)
@classmethod
def __write_to_multiple_tfr(
cls, data: np.array, out_dir: str, filename: str, n_splits: int
):
file_count = 0
max_files = math.ceil(data.shape[0] / n_splits)
print(f"Creating {n_splits} files with {max_files} elements each.")
for i in tqdm(range(n_splits)):
current_shard_name = os.path.join(
out_dir,
f"{filename}.tfrecords-{str(i).zfill(len(str(n_splits)))}-of-{n_splits}",
)
writer = tf.io.TFRecordWriter(current_shard_name)
current_shard_count = 0
while current_shard_count < max_files: # as long as our shard is not full
# get the index of the file that we want to parse now
index = i * max_files + current_shard_count
if index >= len(
data
): # when we have consumed the whole data, preempt generation
break
current_element = data[index]
# create the required Example representation
out = cls._parse_single_element(element=current_element)
writer.write(out.SerializeToString())
current_shard_count += 1
file_count += 1
writer.close()
print(f"\nWrote {file_count} elements to TFRecord")
return file_count
@classmethod
def __get_number_splits(cls, filename: str):
target_size = 100 * 1024 * 1024 # 100mb
single_file_size = os.path.getsize(filename)
number_splits = math.ceil(single_file_size / target_size)
return number_splits
@classmethod
def __write_to_single_tfr(cls, data: np.array, out_dir: str, filename: str):
current_path_name = os.path.join(
out_dir,
f"{filename}.tfrecords-0-of-1",
)
writer = tf.io.TFRecordWriter(current_path_name)
for element in tqdm(data):
writer.write(cls._parse_single_element(element).SerializeToString())
writer.close()
return current_path_name
def load(self) -> tf.data.TFRecordDataset:
path = self.dataset_path
dataset = None
if os.path.isdir(path):
dataset = self._load_folder(path)
elif os.path.isfile(path):
dataset = self._load_file(path)
else:
raise ValueError(f"Path {path} is not a valid file or folder.")
dataset = dataset.map(self._parse_tfr_element)
return dataset
def _load_file(self, path) -> tf.data.TFRecordDataset:
return tf.data.TFRecordDataset(path)
def _load_folder(self, path) -> tf.data.TFRecordDataset:
return tf.data.TFRecordDataset(
glob.glob(os.path.join(path, "**/*.tfrecords*"), recursive=True)
)
class VideoDataset(Dataset):
@classmethod
def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
return {
"frames": _int64_feature(element.shape[0]),
"height": _int64_feature(element.shape[1]),
"width": _int64_feature(element.shape[2]),
"depth": _int64_feature(element.shape[3]),
"raw_video": _bytes_feature(serialize_array(element)),
}
@classmethod
def _parse_tfr_element(cls, element):
# use the same structure as above; it's kinda an outline of the structure we now want to create
data = {
"frames": tf.io.FixedLenFeature([], tf.int64),
"height": tf.io.FixedLenFeature([], tf.int64),
"width": tf.io.FixedLenFeature([], tf.int64),
"raw_video": tf.io.FixedLenFeature([], tf.string),
"depth": tf.io.FixedLenFeature([], tf.int64),
}
content = tf.io.parse_single_example(element, data)
frames = content["frames"]
height = content["height"]
width = content["width"]
depth = content["depth"]
raw_video = content["raw_video"]
# get our 'feature'-- our image -- and reshape it appropriately
feature = tf.io.parse_tensor(raw_video, out_type=tf.uint8)
feature = tf.reshape(feature, shape=[frames, height, width, depth])
return feature
class ImageDataset(Dataset):
@classmethod
def _get_features(cls, element) -> Dict[str, tf.train.Feature]:
return {
"height": _int64_feature(element.shape[0]),
"width": _int64_feature(element.shape[1]),
"depth": _int64_feature(element.shape[2]),
"raw_image": _bytes_feature(serialize_array(element)),
}
@classmethod
def _parse_tfr_element(cls, element):
# use the same structure as above; it's kinda an outline of the structure we now want to create
data = {
"height": tf.io.FixedLenFeature([], tf.int64),
"width": tf.io.FixedLenFeature([], tf.int64),
"raw_image": tf.io.FixedLenFeature([], tf.string),
"depth": tf.io.FixedLenFeature([], tf.int64),
}
content = tf.io.parse_single_example(element, data)
height = content["height"]
width = content["width"]
depth = content["depth"]
raw_image = content["raw_image"]
# get our 'feature'-- our image -- and reshape it appropriately
feature = tf.io.parse_tensor(raw_image, out_type=tf.uint8)
feature = tf.reshape(feature, shape=[height, width, depth])
return feature