# Source: kamangir/bolt (GitHub file view, author: kamangir)
# Commit 8c7151e: "validating single image predict for fashion_mnist - kamangir/bolt#692"
# (page residue: raw · history blame · 3.81 kB)
import argparse
import cv2
from . import *
from .classes import *
from .funcs import *
from abcli import file
import abcli.logging
import logging
logger = logging.getLogger(__name__)
# Command-line interface for this package.
# `name`, `version` and `default_window_size` are not defined in this file;
# presumably they are provided by the star imports above — confirm.
parser = argparse.ArgumentParser(name, description=f"{name}-{version}")
# The task to run. This is a positional argument, so argparse always requires
# it; a `default` would never be applied and is therefore not given.
# NOTE(review): "ingest" is advertised here but has no branch in this file's
# dispatcher (it would log "command not found"); confirm it is handled by a
# caller such as a shell wrapper.
parser.add_argument(
    "task",
    type=str,
    help="eval,ingest,list,predict,predict_image,preprocess,train",
)
# Forwarded to preprocess(objects=...).
parser.add_argument(
    "--objects",
    type=str,
    default="",
)
# Forwarded to Image_Classifier.train(color=...); integer used as a 0/1 flag.
parser.add_argument(
    "--color",
    type=int,
    default=0,
    help="0/1",
)
# Forwarded to Image_Classifier.train(convnet=...); integer used as a 0/1 flag.
parser.add_argument(
    "--convnet",
    type=int,
    default=1,
    help="0/1",
)
# NOTE(review): not read by any task dispatched in this file.
parser.add_argument(
    "--count",
    type=int,
    default=-1,
)
# Dataset folder for "predict" and "train"; image path or URL for
# "predict_image".
parser.add_argument(
    "--data_path",
    type=str,
    default="",
)
# Forwarded to Image_Classifier.train(epochs=...).
parser.add_argument(
    "--epochs",
    default=10,
    type=int,
    help="",
)
# Forwarded to preprocess(infer_annotation=...); integer used as a 0/1 flag.
parser.add_argument(
    "--infer_annotation",
    type=int,
    default=1,
    help="0/1",
)
# First argument of the "eval" task.
parser.add_argument(
    "--input_path",
    type=str,
    default="",
)
# For "predict_image": when non-zero, --data_path is downloaded before use.
parser.add_argument(
    "--is_url",
    type=int,
    default=0,
    help="0/1",
)
# Model location for "list", "predict", "predict_image" and "train".
parser.add_argument(
    "--model_path",
    type=str,
    default="",
)
# Output location for "eval", "predict", "predict_image" and "preprocess".
parser.add_argument(
    "--output_path",
    type=str,
    default="",
)
# Forwarded to preprocess(purpose=...).
parser.add_argument(
    "--purpose",
    type=str,
    default="",
    help="predict/train",
)
# Forwarded to preprocess(window_size=...).
parser.add_argument(
    "--window_size",
    type=int,
    default=default_window_size,
)
args = parser.parse_args()
def _predict_image(options) -> bool:
    """Classify a single image with a saved model and log the top class.

    Loads the classifier from ``options.model_path``. Reads the image from
    ``options.data_path``, downloading it first when ``options.is_url`` is set.
    Resizes it to the classifier's ``window_size`` and averages the channels
    when the classifier's ``color`` parameter is false. Runs
    ``classifier.predict`` on a batch containing just this image.

    Args:
        options: the parsed command-line namespace.

    Returns:
        Truthy on success. False if loading the model, downloading or loading
        the image, or prediction failed.
    """
    classifier = Image_Classifier()
    if not classifier.load(options.model_path):
        return False

    if options.is_url:
        # Fetch the remote image into a local file named by file.auxiliary,
        # presumably keeping the URL's extension — confirm.
        image_filename = file.auxiliary("image", file.extension(options.data_path))
        if not file.download(options.data_path, image_filename):
            return False
    else:
        image_filename = options.data_path

    success, image = file.load_image(image_filename)
    if not success:
        return False

    # Resize to the square window size stored in the classifier's params.
    window_size = classifier.params["window_size"]
    image = cv2.resize(image, (window_size, window_size))

    # Non-color classifiers get a single channel: average over axis 2.
    # An image that is already 2-D is left as is, because
    # np.mean(..., axis=2) would raise on it.
    if not classifier.params["color"] and image.ndim == 3:
        image = np.mean(image, axis=2)

    # Prepend a batch axis so predict() receives a batch of one image.
    image = np.expand_dims(image, axis=0)

    success, prediction = classifier.predict(
        image / 255.0,  # same scaling as the "predict" task; assumes 8-bit pixels
        output_path=options.output_path,
    )
    if success:
        # Row 0 holds the scores for the only image in the batch.
        index = np.argmax(prediction[0])
        logger.info(
            f"prediction: {classifier.class_names[index]} - {prediction[0][index]:.2f}"
        )
    return success


# Dispatch on the requested task; `success` decides the final error log.
success = False
if args.task == "eval":
    # NOTE(review): `eval` is expected to be the project's evaluation function
    # brought in by the star imports (shadowing the builtin) — confirm; the
    # builtin eval would reject a str as its globals argument.
    success = eval(args.input_path, args.output_path)
elif args.task == "list":
    # Loading the model is the whole task here, so its result is the task's
    # result (previously a failed load was still reported as success).
    success = Image_Classifier().load(args.model_path)
elif args.task == "predict":
    classifier = Image_Classifier()
    if classifier.load(args.model_path):
        success, test_images = file.load(f"{args.data_path}/test_images.pyndarray")
        if success:
            # Labels are loaded with civilized=True and default=None; the
            # returned status is ignored, so labels are treated as optional.
            _, test_labels = file.load(
                f"{args.data_path}/test_labels.pyndarray",
                civilized=True,
                default=None,
            )
            success, _ = classifier.predict(
                test_images / 255.0,
                test_labels,
                args.output_path,
            )
elif args.task == "predict_image":
    success = _predict_image(args)
elif args.task == "preprocess":
    success = preprocess(
        args.output_path,
        objects=args.objects,
        infer_annotation=args.infer_annotation,
        purpose=args.purpose,
        window_size=args.window_size,
    )
elif args.task == "train":
    success = Image_Classifier.train(
        args.data_path,
        args.model_path,
        color=args.color,
        convnet=args.convnet,
        epochs=args.epochs,
    )
else:
    logger.error(f"-{name}: {args.task}: command not found.")
if not success:
    logger.error(f"-{name}: {args.task}: failed.")