# kamangir/bolt @ 5877e39: "validating fashion_mnist train" (kamangir/bolt#689)
# (header captured from the GitHub file view: raw / history / blame, 3.23 kB)
import argparse
from . import *
from .classes import *
from .funcs import *
from abcli import file
import abcli.logging
import logging
logger = logging.getLogger(__name__)


def _build_parser():
    """Build the command-line parser for this package's entry point.

    The single positional argument selects the task to run; every other
    option is an optional flag registered from a table, in a fixed order
    (which is also the order ``--help`` lists them in).

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    cli = argparse.ArgumentParser(name, description=f"{name}-{version}")

    # NOTE: argparse never uses `default` for a required positional
    # (nargs unset), so the "" below has no effect; kept as-is.
    cli.add_argument(
        "task",
        type=str,
        default="",
        help="describe,eval,ingest,predict,preprocess,train",
    )

    # (flag, type, default, help). help=None is argparse's own default,
    # i.e. identical to not passing `help` at all.
    # NOTE(review): several of these flags (e.g. --count, --exclude,
    # --include, --negative, --non_empty, --positive, --test_size) are not
    # read by the task dispatch in this script.
    flag_table = (
        ("--objects", str, "", None),
        ("--color", int, 0, "0/1"),
        ("--convnet", int, 1, "0/1"),
        ("--count", int, -1, None),
        ("--data_path", str, "", None),
        ("--epochs", int, 10, ""),
        ("--exclude", str, "", None),
        ("--include", str, "", None),
        ("--infer_annotation", int, 1, "0/1"),
        ("--input_path", str, "", None),
        ("--model_path", str, "", None),
        ("--negative", int, 0, "0/1"),
        ("--non_empty", int, 0, "0/1"),
        ("--output_path", str, "", None),
        ("--positive", int, 0, "0/1"),
        ("--purpose", str, "", "predict/train"),
        ("--test_size", float, 1.0 / 6, None),
        ("--window_size", int, default_window_size, None),
    )
    for flag, arg_type, fallback, description in flag_table:
        cli.add_argument(flag, type=arg_type, default=fallback, help=description)

    return cli


parser = _build_parser()
args = parser.parse_args()
# Dispatch on the requested task. `success` ends up truthy only when the
# selected task reports success; anything else is logged as a failure below.
success = False
if args.task == "describe":
    # The `predict` branch below treats Image_Classifier.load's return value
    # as a success flag; propagate it here too instead of unconditionally
    # reporting success, so a bad --model_path is logged as a failure.
    success = Image_Classifier().load(args.model_path)
elif args.task == "eval":
    # NOTE(review): `eval` must be the package's evaluation function brought
    # in by one of the star imports above. The builtin `eval` would raise
    # TypeError on a string second argument. Confirm it is exported by `.funcs`.
    success = eval(args.input_path, args.output_path)
elif args.task == "predict":
    classifier = Image_Classifier()
    if classifier.load(args.model_path):
        # file.load returns a (success, content) pair.
        success, test_images = file.load(f"{args.data_path}/test_images.pyndarray")
        if success:
            # NOTE(review): `string` is not imported explicitly in this file;
            # presumably abcli's string helpers arrive via a star import -- confirm.
            logger.info(f"test_images: {string.pretty_shape_of_matrix(test_images)}")
            # Labels are optional. With civilized=True / default=None a missing
            # file presumably yields None instead of failing -- confirm in abcli.file.
            _, test_labels = file.load(
                f"{args.data_path}/test_labels.pyndarray",
                civilized=True,
                default=None,
            )
            # Presumably 8-bit pixel values; dividing by 255 maps them to [0, 1].
            test_images = test_images / 255.0
            success = classifier.predict(
                test_images,
                test_labels,
                args.output_path,
            )
elif args.task == "preprocess":
    success = preprocess(
        args.output_path,
        objects=args.objects,
        infer_annotation=args.infer_annotation,
        purpose=args.purpose,
        window_size=args.window_size,
    )
elif args.task == "train":
    success = Image_Classifier.train(
        args.data_path,
        args.model_path,
        color=args.color,
        convnet=args.convnet,
        epochs=args.epochs,
    )
else:
    # NOTE: "ingest" is listed in the task help text but has no branch here,
    # so it also lands in this branch.
    logger.error(f"-{name}: {args.task}: command not found.")

if not success:
    logger.error(f"-{name}: {args.task}: failed.")