import argparse |
import numpy as np |
import sys |
import cv2 |
import tritonclient.grpc as grpcclient |
from tritonclient.utils import InferenceServerException |
from processing import preprocess, postprocess |
from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS |
from labels import COCOLabels |
# Tensor name(s) the model expects as input; only the first entry is used.
INPUT_NAMES = [
    "images",
]
# Output tensors requested from the server. The script indexes this list
# positionally (0..3), so the order matters.
OUTPUT_NAMES = [
    "num_dets",
    "det_boxes",
    "det_scores",
    "det_classes",
]
def _parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # nargs='?' makes the positional optional; without it argparse treats
    # 'mode' as required and default='dummy' could never take effect.
    parser.add_argument('mode',
                        nargs='?',
                        choices=['dummy', 'image', 'video'],
                        default='dummy',
                        help='Run mode, default dummy. \'dummy\' will send a buffer filled with ones to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
    parser.add_argument('input',
                        type=str,
                        nargs='?',
                        help='Input file to load from in image or video mode')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=False,
                        default='yolov7',
                        help='Inference model name, default yolov7')
    parser.add_argument('--width',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input width, default 640')
    parser.add_argument('--height',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input height, default 640')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8001',
                        help='Inference server URL, default localhost:8001')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        required=False,
                        default='',
                        help='Write output into file instead of displaying it')
    parser.add_argument('-f',
                        '--fps',
                        type=float,
                        required=False,
                        default=24.0,
                        help='Video output fps, default 24.0 FPS')
    parser.add_argument('-i',
                        '--model-info',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Print model status, configuration and statistics')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose client output')
    parser.add_argument('-t',
                        '--client-timeout',
                        type=float,
                        required=False,
                        default=None,
                        help='Client timeout in seconds, default no timeout')
    parser.add_argument('-s',
                        '--ssl',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable SSL encrypted channel to the server')
    parser.add_argument('-r',
                        '--root-certificates',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded root certificates, default none')
    parser.add_argument('-p',
                        '--private-key',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded private key, default is none')
    parser.add_argument('-x',
                        '--certificate-chain',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded certificate chain, default is none')
    return parser.parse_args(argv)


def _fail(what, ex=None):
    """Print a failure message (plus server detail, if any) and exit with status 1."""
    print(f"FAILED : {what}")
    if ex is not None:
        print("Got: {}".format(ex.message()))
    sys.exit(1)


def _create_client(flags):
    """Build the Triton gRPC client from *flags*; exit(1) if that fails."""
    try:
        return grpcclient.InferenceServerClient(
            url=flags.url,
            verbose=flags.verbose,
            ssl=flags.ssl,
            root_certificates=flags.root_certificates,
            private_key=flags.private_key,
            certificate_chain=flags.certificate_chain)
    except Exception as e:
        print("context creation failed: " + str(e))
        # Non-zero status so callers/scripts can detect the failure.
        sys.exit(1)


def _check_server_and_model(client, model):
    """Exit(1) unless the server is live and ready and *model* is ready."""
    checks = (
        ("is_server_live", client.is_server_live),
        ("is_server_ready", client.is_server_ready),
        ("is_model_ready", lambda: client.is_model_ready(model)),
    )
    for name, check in checks:
        # A connection error surfaces as an exception rather than False.
        try:
            ok = check()
        except InferenceServerException as ex:
            _fail(name, ex)
        if not ok:
            _fail(name)


def _print_model_info(client, model):
    """Print the model metadata and configuration; exit(1) on any error."""
    try:
        print(client.get_model_metadata(model))
    except InferenceServerException as ex:
        _fail("get_model_metadata", ex)
    try:
        config = client.get_model_config(model)
    except InferenceServerException as ex:
        _fail("get_model_config", ex)
    if config.config.name != model:
        print("FAILED: get_model_config")
        sys.exit(1)
    print(config)


def _print_statistics(client, model):
    """Print server-side inference statistics for *model*; exit(1) if not exactly one entry."""
    statistics = client.get_inference_statistics(model_name=model)
    if len(statistics.model_stats) != 1:
        print("FAILED: get_inference_statistics")
        sys.exit(1)
    print(statistics)


def _requested_outputs():
    """Return one InferRequestedOutput per name in OUTPUT_NAMES."""
    return [grpcclient.InferRequestedOutput(name) for name in OUTPUT_NAMES]


def _infer(client, flags, input_buffer, outputs):
    """Run one synchronous inference on *input_buffer* and return the result.

    The declared input shape is taken from the buffer itself, so it always
    matches the data sent (whatever layout ``preprocess`` produced).
    """
    infer_input = grpcclient.InferInput(INPUT_NAMES[0], list(input_buffer.shape), "FP32")
    infer_input.set_data_from_numpy(input_buffer)
    try:
        return client.infer(model_name=flags.model,
                            inputs=[infer_input],
                            outputs=outputs,
                            client_timeout=flags.client_timeout)
    except InferenceServerException as ex:
        _fail("infer", ex)


def _print_result_summary(results):
    """Print the shape and naive sum of every requested output buffer."""
    for output in OUTPUT_NAMES:
        result = results.as_numpy(output)
        print(f"Received result buffer \"{output}\" of size {result.shape}")
        print(f"Naive buffer sum: {np.sum(result)}")


def _preprocess_batch(image, flags):
    """Preprocess *image* for the model and add a leading batch axis."""
    buffer = preprocess(image, [flags.width, flags.height])
    return np.expand_dims(buffer, axis=0)


def _detect(results, image, flags):
    """Turn raw inference outputs into detections scaled to *image*'s size."""
    num_dets, det_boxes, det_scores, det_classes = (
        results.as_numpy(name) for name in OUTPUT_NAMES)
    return postprocess(num_dets, det_boxes, det_scores, det_classes,
                       image.shape[1], image.shape[0], [flags.width, flags.height])


def _annotate(image, detected_objects):
    """Print each detection and draw its box and label onto *image*; return the image."""
    for box in detected_objects:
        name = COCOLabels(box.classID).name
        label = f"{name}: {box.confidence:.2f}"
        print(f"{name}: {box.confidence}")
        image = render_box(image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
        size = get_text_size(image, label, normalised_scaling=0.6)
        image = render_filled_box(image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]),
                                  color=(220, 220, 220))
        image = render_text(image, label, (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)
    return image


def _run_dummy(client, flags):
    """Send a buffer of ones through the model and summarise the outputs."""
    print("Running in 'dummy' mode")
    print("Creating dummy buffer filled with ones...")
    # Channels-first (N, C, H, W): height precedes width.
    buffer = np.ones(shape=(1, 3, flags.height, flags.width), dtype=np.float32)
    outputs = _requested_outputs()
    print("Invoking inference...")
    results = _infer(client, flags, buffer, outputs)
    if flags.model_info:
        _print_statistics(client, flags.model)
    print("Done")
    _print_result_summary(results)


def _run_image(client, flags):
    """Detect objects in a single image; save the result or display it."""
    print("Running in 'image' mode")
    if not flags.input:
        print("FAILED: no input image")
        sys.exit(1)
    outputs = _requested_outputs()
    print("Creating buffer from image file...")
    input_image = cv2.imread(str(flags.input))
    if input_image is None:
        print(f"FAILED: could not load input image {str(flags.input)}")
        sys.exit(1)
    print("Invoking inference...")
    results = _infer(client, flags, _preprocess_batch(input_image, flags), outputs)
    if flags.model_info:
        _print_statistics(client, flags.model)
    print("Done")
    _print_result_summary(results)
    detected_objects = _detect(results, input_image, flags)
    print(f"Detected objects: {len(detected_objects)}")
    input_image = _annotate(input_image, detected_objects)
    if flags.out:
        # imwrite reports failure (bad path/extension) via its return value.
        if not cv2.imwrite(flags.out, input_image):
            print(f"FAILED: could not write output image {flags.out}")
            sys.exit(1)
        print(f"Saved result to {flags.out}")
    else:
        cv2.imshow('image', input_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


def _run_video(client, flags):
    """Detect objects frame by frame; write an output video or display live ('q' quits)."""
    print("Running in 'video' mode")
    if not flags.input:
        print("FAILED: no input video")
        sys.exit(1)
    outputs = _requested_outputs()
    print("Opening input video stream...")
    cap = cv2.VideoCapture(flags.input)
    if not cap.isOpened():
        print(f"FAILED: cannot open video {flags.input}")
        sys.exit(1)
    writer = None
    counter = 0
    print("Invoking inference...")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("failed to fetch next frame")
                break
            # The writer is opened lazily because its size comes from the first frame.
            if writer is None and flags.out:
                print("Opening output video stream...")
                fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
                writer = cv2.VideoWriter(flags.out, fourcc, flags.fps, (frame.shape[1], frame.shape[0]))
                if not writer.isOpened():
                    print(f"FAILED: cannot open output video {flags.out}")
                    sys.exit(1)
            results = _infer(client, flags, _preprocess_batch(frame, flags), outputs)
            detected_objects = _detect(results, frame, flags)
            print(f"Frame {counter}: {len(detected_objects)} objects")
            counter += 1
            frame = _annotate(frame, detected_objects)
            if writer is not None:
                writer.write(frame)
            else:
                cv2.imshow('image', frame)
                # Mask to the low byte: some backends set modifier bits.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        if flags.model_info:
            _print_statistics(client, flags.model)
        print("Done")
    finally:
        # Runs on normal exit, 'q', sys.exit() and unexpected errors alike;
        # the writer may never have been created (e.g. empty input video).
        cap.release()
        if writer is not None:
            writer.release()
        if not flags.out:
            cv2.destroyAllWindows()


def main(argv=None):
    """Entry point: parse arguments, verify the server/model, then run the selected mode."""
    flags = _parse_args(argv)
    client = _create_client(flags)
    _check_server_and_model(client, flags.model)
    if flags.model_info:
        _print_model_info(client, flags.model)
    runners = {
        'dummy': _run_dummy,
        'image': _run_image,
        'video': _run_video,
    }
    runners[flags.mode](client, flags)


if __name__ == '__main__':
    main()