from abcli import file | |
from abcli import string | |
from abcli.logging import crash_report | |
import os.path | |
from abcli import logging | |
import tensorflow as tf | |
import logging | |
# Module-level logger. ``logging`` here is the standard-library module: the
# plain ``import logging`` above rebinds the name imported from abcli.
logger = logging.getLogger(name=__name__)
def ingest(output_path):
    """Load Fashion-MNIST through Keras and save it under *output_path*.

    Writes one ``<name>.pyndarray`` file for each of ``train_images``,
    ``train_labels``, ``test_images`` and ``test_labels`` (via ``file.save``).
    It also writes ``class_names.json`` containing the ten label names
    (via ``file.save_json``).

    Args:
        output_path: Directory the files are written into. This function does
            not create it; presumably the caller (or ``file.save``) does —
            confirm.

    Returns:
        ``False`` if ``load_data()`` raised. Otherwise, ``True`` only if every
        save call reported success. A failed save does not stop the remaining
        saves.
    """
    # Also imported at module level; kept here so the function is
    # self-contained.
    import tensorflow as tf

    try:
        (train_images, train_labels), (
            test_images,
            test_labels,
        ) = tf.keras.datasets.fashion_mnist.load_data()
    except Exception:
        # ``except Exception`` instead of a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit are not swallowed here.
        crash_report("-fashion_mnist: ingest.")
        return False

    logger.info("ingesting fashion_mnist")

    # Insertion order fixes the order in which files are written.
    arrays = {
        "train_images": train_images,
        "train_labels": train_labels,
        "test_images": test_images,
        "test_labels": test_labels,
    }

    success = True
    for name, thing in arrays.items():
        if file.save(os.path.join(output_path, f"{name}.pyndarray"), thing):
            logger.info(f"ingested {name}: {string.pretty_shape_of_matrix(thing)}")
        else:
            # Keep going so the other arrays are still saved, but make the
            # failure visible instead of only flipping the return value.
            logger.error(f"failed to ingest {name}.")
            success = False

    # NOTE(review): presumably ordered so that index i names integer label i;
    # confirm against the Fashion-MNIST dataset documentation.
    class_names = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]

    if file.save_json(os.path.join(output_path, "class_names.json"), class_names):
        logger.info(
            f"ingested {len(class_names)} class name(s): {', '.join(class_names)}"
        )
    else:
        logger.error("failed to ingest class_names.")
        success = False

    return success