kamangir commited on
Commit
d2b2be7
1 Parent(s): cd2efd1

validating single image predict for fashion_mnist - kamangir/bolt#692

Browse files
abcli/image_classifier.sh CHANGED
@@ -68,12 +68,26 @@ function abcli_image_classifier_predict() {
68
  local data_source=$(abcli_option "$options" "data" object)
69
  local model_source=$(abcli_option "$options" "model" saved)
70
 
71
- local model_path=$(abcli_huggingface get_model_path image-classifier "$model_name" "$options")
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  if [ "$data_source" == "object" ] ; then
74
  abcli_download object $data_object
75
  fi
76
 
 
 
 
77
  if [ "$model_source" == "object" ] ; then
78
  local model_object=TBD
79
  abcli_download object $model_object
@@ -81,7 +95,7 @@ function abcli_image_classifier_predict() {
81
 
82
  abcli_log "image_classifier($model_path).predict($data_object): $options"
83
 
84
- if [ ! -f "$abcli_object_root/$data_object/test_images.pyndarray" ] ; then
85
  python3 -m image_classifier \
86
  preprocess \
87
  --infer_annotation 0 \
@@ -92,17 +106,34 @@ function abcli_image_classifier_predict() {
92
  ${@:4}
93
  fi
94
 
95
- cp -v $abcli_object_root/$data_object/*.pyndarray .
96
- cp -v $model_path/image_classifier/model/class_names.json .
97
 
98
- python3 -m image_classifier \
99
- predict \
100
- --data_path $abcli_object_root/$data_object \
101
- --model_path $model_path \
102
- --output_path $abcli_object_path \
103
- ${@:4}
 
 
 
 
104
 
105
- abcli_tag set . image_classifier,predict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  }
107
 
108
  function abcli_image_classifier_train() {
 
68
  local data_source=$(abcli_option "$options" "data" object)
69
  local model_source=$(abcli_option "$options" "model" saved)
70
 
71
+ if [ "$(abcli_keyword_is $data_object validate)" == true ] ; then
72
+ if [ "$data_source" == "object" ] ; then
73
+ abcli_log_error "-imge-classifier: predict: validation object not found."
74
+ return
75
+ fi
76
+
77
+ if [ "$data_source" == "url" ] ; then
78
+ local data_object="https://upload.wikimedia.org/wikipedia/commons/4/45/Red_High_Heel_Pumps.jpg"
79
+ else
80
+ local data_object="a-validation-filename"
81
+ fi
82
+ fi
83
 
84
  if [ "$data_source" == "object" ] ; then
85
  abcli_download object $data_object
86
  fi
87
 
88
+
89
+ local model_path=$(abcli_huggingface get_model_path image-classifier "$model_name" "$options")
90
+
91
  if [ "$model_source" == "object" ] ; then
92
  local model_object=TBD
93
  abcli_download object $model_object
 
95
 
96
  abcli_log "image_classifier($model_path).predict($data_object): $options"
97
 
98
+ if [ ! -f "$abcli_object_root/$data_object/test_images.pyndarray" ] && [ "$data_source" == "object" ] ; then
99
  python3 -m image_classifier \
100
  preprocess \
101
  --infer_annotation 0 \
 
106
  ${@:4}
107
  fi
108
 
109
+ return
 
110
 
111
+ if [ "$data_source" == "object" ] ; then
112
+ cp -v $abcli_object_root/$data_object/*.pyndarray .
113
+ cp -v $model_path/image_classifier/model/class_names.json .
114
+
115
+ python3 -m image_classifier \
116
+ predict \
117
+ --data_path $abcli_object_root/$data_object \
118
+ --model_path $model_path \
119
+ --output_path $abcli_object_path \
120
+ ${@:4}
121
 
122
+ abcli_tag set . image_classifier,predict
123
+ else
124
+ local is_url=0
125
+ if [ "$data_source" == "url" ] ; then
126
+ local is_url=1
127
+ fi
128
+
129
+ python3 -m image_classifier \
130
+ predict_image \
131
+ --data_path $data_object \
132
+ --is_url $is_url \
133
+ --model_path $model_path \
134
+ --output_path $abcli_object_path \
135
+ ${@:4}
136
+ fi
137
  }
138
 
139
  function abcli_image_classifier_train() {
image_classifier/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  name = "image_classifier"
2
 
3
- version = "1.1.144"
4
 
5
  description = "fashion-mnist + hugging-face + awesome-bash-cli"
 
1
  name = "image_classifier"
2
 
3
+ version = "1.1.145"
4
 
5
  description = "fashion-mnist + hugging-face + awesome-bash-cli"
image_classifier/__main__.py CHANGED
@@ -14,7 +14,7 @@ parser.add_argument(
14
  "task",
15
  type=str,
16
  default="",
17
- help="describe,eval,ingest,predict,preprocess,train",
18
  )
19
  parser.add_argument(
20
  "--objects",
@@ -60,6 +60,12 @@ parser.add_argument(
60
  type=str,
61
  default="",
62
  )
 
 
 
 
 
 
63
  parser.add_argument(
64
  "--model_path",
65
  type=str,
@@ -111,6 +117,29 @@ elif args.task == "predict":
111
  test_labels,
112
  args.output_path,
113
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  elif args.task == "preprocess":
115
  success = preprocess(
116
  args.output_path,
 
14
  "task",
15
  type=str,
16
  default="",
17
+ help="describe,eval,ingest,predict,predict_image,preprocess,train",
18
  )
19
  parser.add_argument(
20
  "--objects",
 
60
  type=str,
61
  default="",
62
  )
63
+ parser.add_argument(
64
+ "--is_url",
65
+ type=int,
66
+ default=0,
67
+ help="0/1",
68
+ )
69
  parser.add_argument(
70
  "--model_path",
71
  type=str,
 
117
  test_labels,
118
  args.output_path,
119
  )
120
+ elif args.task == "predict_image":
121
+ success = True
122
+
123
+ classifier = Image_Classifier()
124
+
125
+ success = classifier.load(args.model_path)
126
+
127
+ if success:
128
+ if args.is_url:
129
+ image_filename = file.auxiliary("image", file.extension(args.data_path))
130
+ if not file.download(args.data_path, image_filename):
131
+ success = False
132
+ else:
133
+ image_filename = args.data_path
134
+
135
+ if success:
136
+ success, image = file.load_image(image_filename)
137
+
138
+ if success:
139
+ success = classifier.predict(
140
+ image / 255.0,
141
+ output_path=args.output_path,
142
+ )
143
  elif args.task == "preprocess":
144
  success = preprocess(
145
  args.output_path,