|
import argparse

import cv2
import numpy as np

from . import *

from .classes import *

from .funcs import *

# NOTE: `file` and `logging` are bound after the star imports on purpose, so
# that any same-named export from the package cannot shadow them.
from abcli import file

import abcli.logging

import logging
|
|
|
# Logger for this command-line entry point.
logger = logging.getLogger(__name__)

# Command-line interface. `name` and `version` come from the package's
# star imports; the description renders as "<name>-<version>".
parser = argparse.ArgumentParser(name, description=f"{name}-{version}")

# Positional task selector, dispatched on below.
# NOTE(review): argparse ignores `default` on a required positional unless
# nargs="?" is set, so "" is never actually used here.
parser.add_argument(
    "task",
    type=str,
    default="",
    help="eval,ingest,list,predict,predict_image,preprocess,train",
)

# Optional flags as (flag, add_argument keyword arguments). They are
# registered in this exact order so that the --help listing is unchanged.
# Integer flags documented "0/1" act as booleans.
for _flag, _options in (
    ("--objects", dict(type=str, default="")),
    ("--color", dict(type=int, default=0, help="0/1")),
    ("--convnet", dict(type=int, default=1, help="0/1")),
    ("--count", dict(type=int, default=-1)),
    ("--data_path", dict(type=str, default="")),
    ("--epochs", dict(default=10, type=int, help="")),
    ("--infer_annotation", dict(type=int, default=1, help="0/1")),
    ("--input_path", dict(type=str, default="")),
    ("--is_url", dict(type=int, default=0, help="0/1")),
    ("--model_path", dict(type=str, default="")),
    ("--output_path", dict(type=str, default="")),
    ("--purpose", dict(type=str, default="", help="predict/train")),
    ("--window_size", dict(type=int, default=default_window_size)),
):
    parser.add_argument(_flag, **_options)

# Keep the loop variables out of the module namespace.
del _flag, _options

args = parser.parse_args()
|
|
|
# Dispatch on the requested task. Each branch sets `success`, and a single
# failure message is logged at the end if it is falsy.
success = False

if args.task == "eval":
    # NOTE(review): this relies on `eval` being supplied by one of the star
    # imports (presumably `.funcs`) and shadowing the builtin. If it is not,
    # the builtin `eval` would be called instead and raise TypeError, because
    # its second argument must be a dict. Confirm, and consider renaming.
    success = eval(args.input_path, args.output_path)

elif args.task == "list":
    # Loading the model is the whole task. Report whether the load actually
    # succeeded instead of unconditionally claiming success.
    success = Image_Classifier().load(args.model_path)

elif args.task == "predict":
    classifier = Image_Classifier()

    if classifier.load(args.model_path):
        success, test_images = file.load(f"{args.data_path}/test_images.pyndarray")

        if success:
            # Labels are optional: civilized=True / default=None lets a
            # missing labels file fall through as None.
            _, test_labels = file.load(
                f"{args.data_path}/test_labels.pyndarray",
                civilized=True,
                default=None,
            )

            # Pixel values are scaled to [0, 1] before prediction, assuming
            # 8-bit input. TODO confirm this matches training preprocessing.
            success, prediction = classifier.predict(
                test_images / 255.0,
                test_labels,
                args.output_path,
            )

elif args.task == "predict_image":
    classifier = Image_Classifier()
    success = classifier.load(args.model_path)

    if success:
        if args.is_url:
            # Download to an auxiliary local file that keeps the source's
            # extension.
            image_filename = file.auxiliary("image", file.extension(args.data_path))
            if not file.download(args.data_path, image_filename):
                success = False
        else:
            image_filename = args.data_path

    if success:
        success, image = file.load_image(image_filename)

    if success:
        # Resize to the square input size the model was trained with.
        window_size = classifier.params["window_size"]
        image = cv2.resize(image, (window_size, window_size))

        # Grayscale models expect one channel, so average the color channels.
        # An image that is already 2-D is left as-is; calling np.mean with
        # axis=2 on it would raise.
        if not classifier.params["color"] and image.ndim == 3:
            image = np.mean(image, axis=2)

        # Add a leading batch axis of size 1.
        image = np.expand_dims(image, axis=0)

        success, prediction = classifier.predict(
            image / 255.0,
            output_path=args.output_path,
        )

    if success:
        # A batch of one image: pick the highest-scoring class of row 0.
        index = np.argmax(prediction[0])
        logger.info(
            f"prediction: {classifier.class_names[index]} - {prediction[0][index]:.2f}"
        )

elif args.task == "preprocess":
    success = preprocess(
        args.output_path,
        objects=args.objects,
        infer_annotation=args.infer_annotation,
        purpose=args.purpose,
        window_size=args.window_size,
    )

elif args.task == "train":
    success = Image_Classifier.train(
        args.data_path,
        args.model_path,
        color=args.color,
        convnet=args.convnet,
        epochs=args.epochs,
    )

else:
    logger.error(f"-{name}: {args.task}: command not found.")

if not success:
    logger.error(f"-{name}: {args.task}: failed.")
|
|