import numpy as np from tqdm.auto import tqdm import tensorflow as tf import tensorflow_datasets as tfds def dataset_statistics(ds): if isinstance(ds, tf.data.Dataset): ds_numpy = tfds.as_numpy(ds) elif isinstance(ds, tf.keras.utils.Sequence): ds_numpy = ds data = [] for da in tqdm(ds_numpy): X, y = da data.append(X) all_data = np.concatenate(data) return np.mean(all_data), np.var(all_data), np.std(all_data)