Update eval_onnx.py
#2
by
hangyang-amd
- opened
- eval_onnx.py +2 -1
eval_onnx.py
CHANGED
@@ -34,6 +34,7 @@ parser.add_argument(
|
|
34 |
default="vaip_config.json",
|
35 |
help="Path of the config file for seting provider_options.",
|
36 |
)
|
|
|
37 |
args = parser.parse_args()
|
38 |
|
39 |
class AverageMeter(object):
|
@@ -144,7 +145,7 @@ def val_imagenet():
|
|
144 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
145 |
with torch.no_grad():
|
146 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
147 |
-
inputs, targets = images.numpy(), targets
|
148 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
149 |
|
150 |
outputs = ort_session.run(None, ort_inputs)
|
|
|
34 |
default="vaip_config.json",
|
35 |
help="Path of the config file for seting provider_options.",
|
36 |
)
|
37 |
+
parser.add_argument('--data_format', type=str, choices=["nchw", "nhwc"], default="nchw")
|
38 |
args = parser.parse_args()
|
39 |
|
40 |
class AverageMeter(object):
|
|
|
145 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
146 |
with torch.no_grad():
|
147 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
148 |
+
inputs, targets = images.numpy() if args.data_format == "nchw" else images.permute((0, 2, 3, 1)).numpy(), targets
|
149 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
150 |
|
151 |
outputs = ort_session.run(None, ort_inputs)
|