kamangir
commited on
Commit
•
d2b2be7
1
Parent(s):
cd2efd1
validating single image predict for fashion_mnist - kamangir/bolt#692
Browse files- abcli/image_classifier.sh +42 -11
- image_classifier/__init__.py +1 -1
- image_classifier/__main__.py +30 -1
abcli/image_classifier.sh
CHANGED
@@ -68,12 +68,26 @@ function abcli_image_classifier_predict() {
|
|
68 |
local data_source=$(abcli_option "$options" "data" object)
|
69 |
local model_source=$(abcli_option "$options" "model" saved)
|
70 |
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
if [ "$data_source" == "object" ] ; then
|
74 |
abcli_download object $data_object
|
75 |
fi
|
76 |
|
|
|
|
|
|
|
77 |
if [ "$model_source" == "object" ] ; then
|
78 |
local model_object=TBD
|
79 |
abcli_download object $model_object
|
@@ -81,7 +95,7 @@ function abcli_image_classifier_predict() {
|
|
81 |
|
82 |
abcli_log "image_classifier($model_path).predict($data_object): $options"
|
83 |
|
84 |
-
if [ ! -f "$abcli_object_root/$data_object/test_images.pyndarray" ] ; then
|
85 |
python3 -m image_classifier \
|
86 |
preprocess \
|
87 |
--infer_annotation 0 \
|
@@ -92,17 +106,34 @@ function abcli_image_classifier_predict() {
|
|
92 |
${@:4}
|
93 |
fi
|
94 |
|
95 |
-
|
96 |
-
cp -v $model_path/image_classifier/model/class_names.json .
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
104 |
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
}
|
107 |
|
108 |
function abcli_image_classifier_train() {
|
|
|
68 |
local data_source=$(abcli_option "$options" "data" object)
|
69 |
local model_source=$(abcli_option "$options" "model" saved)
|
70 |
|
71 |
+
if [ "$(abcli_keyword_is $data_object validate)" == true ] ; then
|
72 |
+
if [ "$data_source" == "object" ] ; then
|
73 |
+
abcli_log_error "-imge-classifier: predict: validation object not found."
|
74 |
+
return
|
75 |
+
fi
|
76 |
+
|
77 |
+
if [ "$data_source" == "url" ] ; then
|
78 |
+
local data_object="https://upload.wikimedia.org/wikipedia/commons/4/45/Red_High_Heel_Pumps.jpg"
|
79 |
+
else
|
80 |
+
local data_object="a-validation-filename"
|
81 |
+
fi
|
82 |
+
fi
|
83 |
|
84 |
if [ "$data_source" == "object" ] ; then
|
85 |
abcli_download object $data_object
|
86 |
fi
|
87 |
|
88 |
+
|
89 |
+
local model_path=$(abcli_huggingface get_model_path image-classifier "$model_name" "$options")
|
90 |
+
|
91 |
if [ "$model_source" == "object" ] ; then
|
92 |
local model_object=TBD
|
93 |
abcli_download object $model_object
|
|
|
95 |
|
96 |
abcli_log "image_classifier($model_path).predict($data_object): $options"
|
97 |
|
98 |
+
if [ ! -f "$abcli_object_root/$data_object/test_images.pyndarray" ] && [ "$data_source" == "object" ] ; then
|
99 |
python3 -m image_classifier \
|
100 |
preprocess \
|
101 |
--infer_annotation 0 \
|
|
|
106 |
${@:4}
|
107 |
fi
|
108 |
|
109 |
+
return
|
|
|
110 |
|
111 |
+
if [ "$data_source" == "object" ] ; then
|
112 |
+
cp -v $abcli_object_root/$data_object/*.pyndarray .
|
113 |
+
cp -v $model_path/image_classifier/model/class_names.json .
|
114 |
+
|
115 |
+
python3 -m image_classifier \
|
116 |
+
predict \
|
117 |
+
--data_path $abcli_object_root/$data_object \
|
118 |
+
--model_path $model_path \
|
119 |
+
--output_path $abcli_object_path \
|
120 |
+
${@:4}
|
121 |
|
122 |
+
abcli_tag set . image_classifier,predict
|
123 |
+
else
|
124 |
+
local is_url=0
|
125 |
+
if [ "$data_source" == "url" ] ; then
|
126 |
+
local is_url=1
|
127 |
+
fi
|
128 |
+
|
129 |
+
python3 -m image_classifier \
|
130 |
+
predict_image \
|
131 |
+
--data_path $data_object \
|
132 |
+
--is_url $is_url \
|
133 |
+
--model_path $model_path \
|
134 |
+
--output_path $abcli_object_path \
|
135 |
+
${@:4}
|
136 |
+
fi
|
137 |
}
|
138 |
|
139 |
function abcli_image_classifier_train() {
|
image_classifier/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
name = "image_classifier"
|
2 |
|
3 |
-
version = "1.1.
|
4 |
|
5 |
description = "fashion-mnist + hugging-face + awesome-bash-cli"
|
|
|
1 |
name = "image_classifier"
|
2 |
|
3 |
+
version = "1.1.145"
|
4 |
|
5 |
description = "fashion-mnist + hugging-face + awesome-bash-cli"
|
image_classifier/__main__.py
CHANGED
@@ -14,7 +14,7 @@ parser.add_argument(
|
|
14 |
"task",
|
15 |
type=str,
|
16 |
default="",
|
17 |
-
help="describe,eval,ingest,predict,preprocess,train",
|
18 |
)
|
19 |
parser.add_argument(
|
20 |
"--objects",
|
@@ -60,6 +60,12 @@ parser.add_argument(
|
|
60 |
type=str,
|
61 |
default="",
|
62 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
parser.add_argument(
|
64 |
"--model_path",
|
65 |
type=str,
|
@@ -111,6 +117,29 @@ elif args.task == "predict":
|
|
111 |
test_labels,
|
112 |
args.output_path,
|
113 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
elif args.task == "preprocess":
|
115 |
success = preprocess(
|
116 |
args.output_path,
|
|
|
14 |
"task",
|
15 |
type=str,
|
16 |
default="",
|
17 |
+
help="describe,eval,ingest,predict,predict_image,preprocess,train",
|
18 |
)
|
19 |
parser.add_argument(
|
20 |
"--objects",
|
|
|
60 |
type=str,
|
61 |
default="",
|
62 |
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--is_url",
|
65 |
+
type=int,
|
66 |
+
default=0,
|
67 |
+
help="0/1",
|
68 |
+
)
|
69 |
parser.add_argument(
|
70 |
"--model_path",
|
71 |
type=str,
|
|
|
117 |
test_labels,
|
118 |
args.output_path,
|
119 |
)
|
120 |
+
elif args.task == "predict_image":
|
121 |
+
success = True
|
122 |
+
|
123 |
+
classifier = Image_Classifier()
|
124 |
+
|
125 |
+
success = classifier.load(args.model_path)
|
126 |
+
|
127 |
+
if success:
|
128 |
+
if args.is_url:
|
129 |
+
image_filename = file.auxiliary("image", file.extension(args.data_path))
|
130 |
+
if not file.download(args.data_path, image_filename):
|
131 |
+
success = False
|
132 |
+
else:
|
133 |
+
image_filename = args.data_path
|
134 |
+
|
135 |
+
if success:
|
136 |
+
success, image = file.load_image(image_filename)
|
137 |
+
|
138 |
+
if success:
|
139 |
+
success = classifier.predict(
|
140 |
+
image / 255.0,
|
141 |
+
output_path=args.output_path,
|
142 |
+
)
|
143 |
elif args.task == "preprocess":
|
144 |
success = preprocess(
|
145 |
args.output_path,
|