kamangir
fashion-mnist -> image-classifier - kamangir/bolt#689
fc0b387
raw
history blame
4.03 kB
import argparse
import cv2
from functools import reduce
import matplotlib.pyplot as plt
import numpy as np
import os
import os.path
import tensorflow as tf
from tqdm import *
import re
import time
from . import *
from abcli import objects
from abcli import cache
from abcli import file
from abcli.tasks import host
from abcli import graphics
from abcli.options import Options
from abcli import path
from abcli.storage import instance as storage
from abcli import string
from abcli.plugins import tags
import abcli.logging
import logging
logger = logging.getLogger(__name__)

# Command-line interface: one positional task name followed by optional
# flags. Flags documented as "0/1" are integers used as booleans.
parser = argparse.ArgumentParser(name, description=f"{name}-{version}")
parser.add_argument(
    "task",
    type=str,
    default="",
    help="describe,eval,ingest,predict,preprocess,train",
)

# Optional flags, in help-listing order, as (flag, type, default, help).
# A help of None is the same as not passing `help` to add_argument.
_OPTIONAL_ARGUMENTS = [
    ("--objects", str, "", None),
    ("--color", int, 0, "0/1"),
    ("--convnet", int, 1, "0/1"),
    ("--count", int, -1, None),
    ("--data_path", str, "", None),
    ("--epochs", int, 10, ""),
    ("--exclude", str, "", None),
    ("--include", str, "", None),
    ("--infer_annotation", int, 1, "0/1"),
    ("--input_path", str, "", None),
    ("--model_path", str, "", None),
    ("--negative", int, 0, "0/1"),
    ("--non_empty", int, 0, "0/1"),
    ("--output_path", str, "", None),
    ("--positive", int, 0, "0/1"),
    ("--purpose", str, "", "predict/train"),
    ("--test_size", float, 1.0 / 6, None),
    ("--window_size", int, 28, None),
]
for _flag, _kind, _fallback, _description in _OPTIONAL_ARGUMENTS:
    parser.add_argument(
        _flag,
        type=_kind,
        default=_fallback,
        help=_description,
    )

args = parser.parse_args()
# Run the task named on the command line. Every branch stores its outcome in
# `success`; any falsy outcome is reported as a failure at the end.
success = False
if args.task == "describe":
    # `load` returns a truthy value on success (the "predict" branch below
    # relies on this), so propagate it rather than unconditionally reporting
    # success when the model could not be loaded.
    success = image_classifier().load(args.model_path)
elif args.task == "eval":
    # NOTE(review): `eval` is presumably the package's own eval function
    # brought in by `from . import *`, shadowing the builtin `eval` — confirm.
    # If the package does not define it, this calls the builtin instead.
    success = eval(args.input_path, args.output_path)
elif args.task == "ingest":
    success = ingest(
        args.include,
        args.output_path,
        {
            "count": args.count,
            "exclude": args.exclude,
            "negative": args.negative,
            "non_empty": args.non_empty,
            "positive": args.positive,
            "test_size": args.test_size,
        },
    )
elif args.task == "predict":
    classifier = image_classifier()
    if classifier.load(args.model_path):
        success, test_images = file.load(
            "{}/test_images.pyndarray".format(args.data_path)
        )
        if success:
            logger.info("test_images: {}".format(string.pretty_size_of_matrix(test_images)))
            # Labels are optional; default=None is passed so that prediction
            # can presumably proceed without them — confirm file.load semantics.
            _, test_labels = file.load(
                "{}/test_labels.pyndarray".format(args.data_path),
                civilized=True,
                default=None,
            )
            # Scale pixel values; assumes 0-255 input range — TODO confirm.
            test_images = test_images / 255.0
            success = classifier.predict(test_images, test_labels, args.output_path)
elif args.task == "preprocess":
    success = preprocess(
        args.output_path,
        {
            "objects": args.objects,
            "infer_annotation": args.infer_annotation,
            "purpose": args.purpose,
            "window_size": args.window_size,
        },
    )
elif args.task == "train":
    classifier = image_classifier()
    success = classifier.train(
        args.data_path,
        args.model_path,
        {"color": args.color, "convnet": args.convnet, "epochs": args.epochs},
    )
else:
    logger.error(f"-{name}: {args.task}: command not found.")

if not success:
    logger.error(f"-{name}: {args.task}: failed.")