#!/usr/bin/env python
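"""Triton Inference Server gRPC client for YOLOv7 object detection.

Supports three modes: 'dummy' (send a synthetic buffer to verify inference
works), 'image' (detect objects in a single image), and 'video' (process a
video stream frame by frame). Illustrative invocation, assuming a running
Triton server with a 'yolov7' model and an example image path:

    python client.py image data/dog.jpg -o result.jpg
"""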
import argparse
import sys

import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

from processing import preprocess, postprocess
from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS
from labels import COCOLabels
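
# These tensor names must match the deployed model's configuration
# (config.pbtxt); the four outputs are those of a YOLOv7 export with
# non-max suppression included in the model.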
INPUT_NAMES = ["images"]
OUTPUT_NAMES = ["num_dets", "det_boxes", "det_scores", "det_classes"]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('mode',
                        choices=['dummy', 'image', 'video'],
                        default='dummy',
                        help='Run mode. \'dummy\' will send an empty buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
    parser.add_argument('input',
                        type=str,
                        nargs='?',
                        help='Input file to load from in image or video mode')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=False,
                        default='yolov7',
                        help='Inference model name, default yolov7')
    parser.add_argument('--width',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input width, default 640')
    parser.add_argument('--height',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input height, default 640')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8001',
                        help='Inference server URL, default localhost:8001')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        required=False,
                        default='',
                        help='Write output into file instead of displaying it')
    parser.add_argument('-f',
                        '--fps',
                        type=float,
                        required=False,
                        default=24.0,
                        help='Video output fps, default 24.0 FPS')
    parser.add_argument('-i',
                        '--model-info',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Print model status, configuration and statistics')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose client output')
    parser.add_argument('-t',
                        '--client-timeout',
                        type=float,
                        required=False,
                        default=None,
                        help='Client timeout in seconds, default no timeout')
    parser.add_argument('-s',
                        '--ssl',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable SSL encrypted channel to the server')
    parser.add_argument('-r',
                        '--root-certificates',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded root certificates, default none')
    parser.add_argument('-p',
                        '--private-key',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded private key, default is none')
    parser.add_argument('-x',
                        '--certificate-chain',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded certificate chain, default is none')
    FLAGS = parser.parse_args()

    # Create server context
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=FLAGS.url,
            verbose=FLAGS.verbose,
            ssl=FLAGS.ssl,
            root_certificates=FLAGS.root_certificates,
            private_key=FLAGS.private_key,
            certificate_chain=FLAGS.certificate_chain)
    except Exception as e:
        print("context creation failed: " + str(e))
        sys.exit(1)
    # Health check
    if not triton_client.is_server_live():
        print("FAILED : is_server_live")
        sys.exit(1)

    if not triton_client.is_server_ready():
        print("FAILED : is_server_ready")
        sys.exit(1)

    if not triton_client.is_model_ready(FLAGS.model):
        print("FAILED : is_model_ready")
        sys.exit(1)
    if FLAGS.model_info:
        # Model metadata
        try:
            metadata = triton_client.get_model_metadata(FLAGS.model)
            print(metadata)
        except InferenceServerException as ex:
            if "Request for unknown model" not in ex.message():
                print("FAILED : get_model_metadata")
                print("Got: {}".format(ex.message()))
                sys.exit(1)
            else:
                print("FAILED : get_model_metadata")
                sys.exit(1)

        # Model configuration
        try:
            config = triton_client.get_model_config(FLAGS.model)
            if config.config.name != FLAGS.model:
                print("FAILED: get_model_config")
                sys.exit(1)
            print(config)
        except InferenceServerException as ex:
            print("FAILED : get_model_config")
            print("Got: {}".format(ex.message()))
            sys.exit(1)
    # DUMMY MODE
    if FLAGS.mode == 'dummy':
        print("Running in 'dummy' mode")
        print("Creating buffer filled with ones...")

        inputs = []
        outputs = []
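        # The input is assumed to be NCHW FP32; with the default square
        # 640x640 input the width/height order in the shape is interchangeable,
        # but both must match the deployed model's input dimensions.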
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        inputs[0].set_data_from_numpy(np.ones(shape=(1, 3, FLAGS.width, FLAGS.height), dtype=np.float32))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))
print("Invoking inference...") | |
results = triton_client.infer(model_name=FLAGS.model, | |
inputs=inputs, | |
outputs=outputs, | |
client_timeout=FLAGS.client_timeout) | |
if FLAGS.model_info: | |
statistics = triton_client.get_inference_statistics(model_name=FLAGS.model) | |
if len(statistics.model_stats) != 1: | |
print("FAILED: get_inference_statistics") | |
sys.exit(1) | |
print(statistics) | |
print("Done") | |
for output in OUTPUT_NAMES: | |
result = results.as_numpy(output) | |
print(f"Received result buffer \"{output}\" of size {result.shape}") | |
print(f"Naive buffer sum: {np.sum(result)}") | |
    # IMAGE MODE
    if FLAGS.mode == 'image':
        print("Running in 'image' mode")
        if not FLAGS.input:
            print("FAILED: no input image")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Creating buffer from image file...")
        input_image = cv2.imread(str(FLAGS.input))
        if input_image is None:
            print(f"FAILED: could not load input image {str(FLAGS.input)}")
            sys.exit(1)
        input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height])
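        # prepend a batch dimension: the model expects a single NCHW batch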
        input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
        inputs[0].set_data_from_numpy(input_image_buffer)

        print("Invoking inference...")
        results = triton_client.infer(model_name=FLAGS.model,
                                      inputs=inputs,
                                      outputs=outputs,
                                      client_timeout=FLAGS.client_timeout)
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        for output in OUTPUT_NAMES:
            result = results.as_numpy(output)
            print(f"Received result buffer \"{output}\" of size {result.shape}")
            print(f"Naive buffer sum: {np.sum(result)}")
        num_dets = results.as_numpy(OUTPUT_NAMES[0])
        det_boxes = results.as_numpy(OUTPUT_NAMES[1])
        det_scores = results.as_numpy(OUTPUT_NAMES[2])
        det_classes = results.as_numpy(OUTPUT_NAMES[3])
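        # postprocess scales the detections back from network input size to
        # the original image resolution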
        detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height])
        print(f"Detected objects: {len(detected_objects)}")

        for box in detected_objects:
            print(f"{COCOLabels(box.classID).name}: {box.confidence}")
            input_image = render_box(input_image, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
            size = get_text_size(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
            input_image = render_filled_box(input_image, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
            input_image = render_text(input_image, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)

        if FLAGS.out:
            cv2.imwrite(FLAGS.out, input_image)
            print(f"Saved result to {FLAGS.out}")
        else:
            cv2.imshow('image', input_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
    # VIDEO MODE
    if FLAGS.mode == 'video':
        print("Running in 'video' mode")
        if not FLAGS.input:
            print("FAILED: no input video")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput(INPUT_NAMES[0], [1, 3, FLAGS.width, FLAGS.height], "FP32"))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[0]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[1]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[2]))
        outputs.append(grpcclient.InferRequestedOutput(OUTPUT_NAMES[3]))

        print("Opening input video stream...")
        cap = cv2.VideoCapture(FLAGS.input)
        if not cap.isOpened():
            print(f"FAILED: cannot open video {FLAGS.input}")
            sys.exit(1)

        counter = 0
        out = None
print("Invoking inference...") | |
while True: | |
ret, frame = cap.read() | |
if not ret: | |
print("failed to fetch next frame") | |
break | |
            if counter == 0 and FLAGS.out:
                print("Opening output video stream...")
                fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
                out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0]))

            input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
            input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
            inputs[0].set_data_from_numpy(input_image_buffer)

            results = triton_client.infer(model_name=FLAGS.model,
                                          inputs=inputs,
                                          outputs=outputs,
                                          client_timeout=FLAGS.client_timeout)

            num_dets = results.as_numpy(OUTPUT_NAMES[0])
            det_boxes = results.as_numpy(OUTPUT_NAMES[1])
            det_scores = results.as_numpy(OUTPUT_NAMES[2])
            det_classes = results.as_numpy(OUTPUT_NAMES[3])
            detected_objects = postprocess(num_dets, det_boxes, det_scores, det_classes, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height])
            print(f"Frame {counter}: {len(detected_objects)} objects")
            counter += 1

            for box in detected_objects:
                print(f"{COCOLabels(box.classID).name}: {box.confidence}")
                frame = render_box(frame, box.box(), color=tuple(RAND_COLORS[box.classID % 64].tolist()))
                size = get_text_size(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", normalised_scaling=0.6)
                frame = render_filled_box(frame, (box.x1 - 3, box.y1 - 3, box.x1 + size[0], box.y1 + size[1]), color=(220, 220, 220))
                frame = render_text(frame, f"{COCOLabels(box.classID).name}: {box.confidence:.2f}", (box.x1, box.y1), color=(30, 30, 30), normalised_scaling=0.5)

            if FLAGS.out:
                out.write(frame)
            else:
                cv2.imshow('image', frame)
                if cv2.waitKey(1) == ord('q'):
                    break
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)

        print("Done")
        cap.release()
        if FLAGS.out:
            out.release()
        else:
            cv2.destroyAllWindows()