kamangir
validating fashion_mnist train - kamangir/bolt#689
08b815a
raw
history blame
1.44 kB
from abcli import file
from abcli import string
from abcli.logging import crash_report
import os.path
from abcli import logging
import tensorflow as tf
import logging
logger = logging.getLogger(__name__)
def ingest(output_path: str) -> bool:
    """Download Fashion-MNIST via Keras and save it under *output_path*.

    Writes one ``<name>.pyndarray`` file for each of ``train_images``,
    ``train_labels``, ``test_images`` and ``test_labels`` using
    ``file.save``, plus ``class_names.json`` (the ten class names) using
    ``file.save_json``.

    Args:
        output_path: Directory the files are written into.

    Returns:
        ``True`` if the dataset loaded and every save call returned a
        truthy value; ``False`` if loading raised or any save call
        returned a falsy value. A failed save does not stop the remaining
        files from being attempted.
    """
    # Function-scope import, kept from the original so the function stays
    # self-contained even though the module also imports tensorflow.
    import tensorflow as tf

    try:
        fashion_mnist = tf.keras.datasets.fashion_mnist
        (train_images, train_labels), (
            test_images,
            test_labels,
        ) = fashion_mnist.load_data()
    except Exception:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt /
        # SystemExit still propagate instead of being reported as a
        # failed ingest.
        crash_report("-fashion_mnist: ingest.")
        return False

    logger.info("ingesting fashion_mnist")

    success = True

    arrays = (
        ("train_images", train_images),
        ("train_labels", train_labels),
        ("test_images", test_images),
        ("test_labels", test_labels),
    )
    for name, thing in arrays:
        if file.save(os.path.join(output_path, f"{name}.pyndarray"), thing):
            logger.info(f"ingested {name}: {string.pretty_shape_of_matrix(thing)}")
        else:
            # Best effort: remember the failure but keep saving the rest.
            success = False

    # Label index -> human-readable class name, in Fashion-MNIST label order.
    class_names = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
    if file.save_json(os.path.join(output_path, "class_names.json"), class_names):
        logger.info(
            f"ingested {len(class_names)} class name(s): {', '.join(class_names)}"
        )
    else:
        success = False

    return success