Spaces:
Runtime error
Runtime error
File size: 5,676 Bytes
80df0b9 9970a74 80df0b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import argparse
import os
import cv2
import numpy as np
from loguru import logger
import onnxruntime
from yolox.data.data_augment import preproc as preprocess
from yolox.utils import mkdir, multiclass_nms, demo_postprocess, vis
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer
def make_parser():
    """Build the command-line parser for the ONNX Runtime + ByteTrack video demo.

    Returns:
        argparse.ArgumentParser: parser with detector options (model path,
        input video, output directory, score/NMS thresholds, input shape,
        P6 flag) and tracker options (track/match thresholds, track buffer,
        minimum box area, MOT20 flag).
    """
    parser = argparse.ArgumentParser("onnxruntime inference sample")
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="bytetrack_s.onnx",
        help="Input your onnx model.",
    )
    parser.add_argument(
        "-i",
        "--video_path",
        type=str,
        default='../../videos/palace.mp4',
        help="Path to your input video.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default='.',
        help="Path to your output directory.",
    )
    parser.add_argument(
        "-s",
        "--score_thr",
        type=float,
        default=0.1,
        help="Score threshold to filter the result.",
    )
    parser.add_argument(
        "-n",
        "--nms_thr",
        type=float,
        default=0.7,
        help="NMS threshold.",
    )
    # Comma-separated pair of ints, e.g. "608,1088" (presumably height,width
    # of the network input — confirm against the exported model).
    parser.add_argument(
        "--input_shape",
        type=str,
        default="608,1088",
        help="Specify an input shape for inference.",
    )
    parser.add_argument(
        "--with_p6",
        action="store_true",
        help="Whether your model uses p6 in FPN/PAN.",
    )
    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
    # The threshold is fractional, so it must be parsed as float;
    # type=int rejected any explicit value such as "0.9".
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser
class Predictor(object):
    """ONNX Runtime detector that turns one video frame into tracker-ready detections.

    Args:
        args: parsed CLI namespace (see ``make_parser``); this class reads
            ``model``, ``input_shape``, ``with_p6``, ``nms_thr`` and ``score_thr``.
    """

    def __init__(self, args):
        # Per-channel normalisation constants passed to ``preprocess``.
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        self.args = args
        self.session = onnxruntime.InferenceSession(args.model)
        # The model's input tensor name is fixed; look it up once instead of per frame.
        self.input_name = self.session.get_inputs()[0].name
        # "608,1088" -> (608, 1088)
        self.input_shape = tuple(map(int, args.input_shape.split(',')))

    @staticmethod
    def _cxcywh_to_xyxy(boxes, ratio):
        """Convert rows of (cx, cy, w, h) to (x1, y1, x2, y2) and divide by ``ratio``.

        ``ratio`` is the resize factor returned by ``preprocess``; dividing maps
        the boxes back to original-image coordinates.
        """
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
        boxes_xyxy /= ratio
        return boxes_xyxy

    @staticmethod
    def _to_tracker_dets(dets):
        """Drop the last column of the NMS output, tolerating "no detections".

        NOTE(review): yolox's ``multiclass_nms`` returns ``None`` when no box
        passes ``score_thr`` and otherwise 6 columns (box, score, class) —
        confirm against the installed yolox. An empty 5-column array is
        returned for ``None`` so the tracker receives a well-formed input
        instead of the demo crashing on a frame with no detections.
        """
        if dets is None:
            return np.zeros((0, 5))
        return dets[:, :-1]

    def inference(self, ori_img, timer):
        """Run detection on a single frame.

        Args:
            ori_img: image array; ``shape[:2]`` is read as (height, width).
            timer: object whose ``tic()`` is called right before the ONNX run;
                the caller is responsible for the matching ``toc()``.

        Returns:
            tuple: ``(dets, img_info)`` where ``dets`` is the NMS output with
            its last column removed (an empty ``(0, 5)`` array when nothing
            is detected), and ``img_info`` is a dict with keys ``id``,
            ``height``, ``width``, ``raw_img`` and ``ratio``.
        """
        height, width = ori_img.shape[:2]
        img_info = {"id": 0, "height": height, "width": width, "raw_img": ori_img}

        img, ratio = preprocess(ori_img, self.input_shape, self.rgb_means, self.std)
        img_info["ratio"] = ratio

        # Prepend a batch axis of size 1.
        ort_inputs = {self.input_name: img[None, :, :, :]}
        timer.tic()
        output = self.session.run(None, ort_inputs)
        predictions = demo_postprocess(output[0], self.input_shape, p6=self.args.with_p6)[0]

        boxes_xyxy = self._cxcywh_to_xyxy(predictions[:, :4], ratio)
        # Column 4 (presumably objectness) times columns 5: (presumably
        # per-class probabilities) gives one score per class.
        scores = predictions[:, 4:5] * predictions[:, 5:]
        dets = multiclass_nms(boxes_xyxy, scores, nms_thr=self.args.nms_thr, score_thr=self.args.score_thr)
        return self._to_tracker_dets(dets), img_info
def imageflow_demo(predictor, args):
    """Track objects through ``args.video_path`` and write an annotated copy.

    The annotated video is written to ``args.output_dir`` using the input
    video's file name.

    Args:
        predictor: object with ``inference(frame, timer) -> (dets, img_info)``,
            e.g. ``Predictor``.
        args: parsed CLI namespace (see ``make_parser``).

    Returns:
        list: one ``(frame_number, tlwhs, track_ids, scores)`` tuple per
        processed frame (frame numbers start at 1). Empty when the video
        cannot be opened.
    """
    cap = cv2.VideoCapture(args.video_path)
    if not cap.isOpened():
        logger.error(f"could not open video {args.video_path}")
        cap.release()
        return []
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = args.output_dir
    os.makedirs(save_folder, exist_ok=True)
    # basename also handles platform-specific separators, unlike split("/").
    save_path = os.path.join(save_folder, os.path.basename(args.video_path))
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    # NOTE(review): frame_rate is fixed at 30 rather than the video's ``fps``.
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []
    try:
        while True:
            if frame_id % 20 == 0:
                logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
            ret_val, frame = cap.read()
            if not ret_val:
                # End of stream or read failure.
                break
            outputs, img_info = predictor.inference(frame, timer)
            online_targets = tracker.update(outputs, [img_info['height'], img_info['width']], [img_info['height'], img_info['width']])
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                # Skip boxes more than 1.6x wider than tall and boxes whose
                # area is not above --min-box-area.
                too_wide = tlwh[2] / tlwh[3] > 1.6
                if tlwh[2] * tlwh[3] > args.min_box_area and not too_wide:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
            # Pairs with the tic() inside predictor.inference().
            timer.toc()
            results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))
            # Same guard as the progress log: avoid dividing by a zero average.
            online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                                      fps=1. / max(1e-5, timer.average_time))
            vid_writer.write(online_im)
            # NOTE(review): no HighGUI window is created, so this key check is
            # unlikely to ever see a key press.
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
            frame_id += 1
    finally:
        # Release handles even on error so the output file is finalised.
        cap.release()
        vid_writer.release()
    return results
if __name__ == '__main__':
    # Parse CLI options, build the ONNX detector, then run the tracking demo.
    args = make_parser().parse_args()
    predictor = Predictor(args)
    imageflow_demo(predictor, args)