# Spaces:
# Runtime error
# Runtime error
from loguru import logger | |
import cv2 | |
import torch | |
from yolox.data.data_augment import preproc | |
from yolox.exp import get_exp | |
from yolox.utils import fuse_model, get_model_info, postprocess, vis | |
from yolox.utils.visualize import plot_tracking | |
from yolox.tracker.byte_tracker import BYTETracker | |
from yolox.tracking_utils.timer import Timer | |
import argparse | |
import os | |
import time | |
# File extensions (with the leading dot) that get_image_list() accepts as
# image files when scanning a directory.
IMAGE_EXT = [
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".png",
]
def make_parser():
    """Build the command-line parser for the ByteTrack demo.

    Returns:
        argparse.ArgumentParser: parser for the demo type, input source,
        detector/experiment options (checkpoint, device, conf/nms/size,
        fp16/fuse/TensorRT) and tracker options (``track_thresh``,
        ``track_buffer``, ``match_thresh``, ``min_box_area``, ``mot20``).
    """
    parser = argparse.ArgumentParser("ByteTrack Demo!")
    # nargs="?" lets the documented default ("image") actually apply: argparse
    # ignores `default` on a plain positional argument and always requires it.
    parser.add_argument(
        "demo", nargs="?", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    # An image folder also works, e.g. "./datasets/mot/train/MOT17-05-FRCNN/img1".
    parser.add_argument(
        "--path", default="./videos/palace.mp4", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )
    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="pls input your expriment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
    # type=float: the threshold is fractional (default 0.8); the previous
    # type=int made any command-line value such as "0.9" a parse error.
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser
def get_image_list(path, extensions=None):
    """Recursively collect image files under ``path``.

    Args:
        path: directory to walk with ``os.walk``.
        extensions: iterable of file extensions including the leading dot.
            Defaults to the module-level ``IMAGE_EXT``. Matching ignores case,
            so ``photo.JPG`` is accepted for ``.jpg``.

    Returns:
        list[str]: matching file paths (joined onto ``path``) in ``os.walk``
        order; callers sort them if they need a stable order.
    """
    if extensions is None:
        extensions = IMAGE_EXT
    # Lower-case both sides so upper-case extensions (common on cameras and
    # Windows) are not silently skipped.
    wanted = {ext.lower() for ext in extensions}
    image_names = []
    for maindir, _subdirs, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            if os.path.splitext(apath)[1].lower() in wanted:
                image_names.append(apath)
    return image_names
def write_results(filename, results):
    """Write tracking results to ``filename``, one comma-separated line per box.

    Each line has the form ``frame,id,x1,y1,w,h,score,-1,-1,-1``. Box
    coordinates are rounded to one decimal and scores to two. Boxes whose
    track id is negative are skipped. An existing file is overwritten.

    Args:
        filename: output path.
        results: iterable of ``(frame_id, tlwhs, track_ids, scores)`` tuples;
            the three sequences are aligned box by box.
    """
    row_template = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
    with open(filename, 'w') as out_file:
        for frame_idx, boxes, ids, confidences in results:
            for box, tid, conf in zip(boxes, ids, confidences):
                if tid < 0:
                    continue
                left, top, box_w, box_h = box
                row = row_template.format(
                    frame=frame_idx,
                    id=tid,
                    x1=round(left, 1),
                    y1=round(top, 1),
                    w=round(box_w, 1),
                    h=round(box_h, 1),
                    s=round(conf, 2),
                )
                out_file.write(row)
    logger.info('save results to {}'.format(filename))
class Predictor(object):
    """Run the detector on a single image and post-process its raw output.

    Wraps either the PyTorch model passed in or, when ``trt_file`` is given,
    a ``torch2trt.TRTModule`` loaded from that file.
    """

    def __init__(
        self,
        model,
        exp,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False
    ):
        """
        Args:
            model: detector model, already placed on its device/dtype by the
                caller.
            exp: experiment object; ``num_classes``, ``test_conf``, ``nmsthre``
                and ``test_size`` are read from it.
            trt_file: optional path to a saved ``TRTModule`` state dict. When
                given, ``model`` is run once on a dummy CUDA input and then
                replaced by the TensorRT module.
            decoder: optional callable applied to raw model outputs before
                ``postprocess``.
            device: ``"gpu"`` moves inputs to CUDA; any other value keeps them
                on the CPU.
            fp16: when true, inputs are cast to half precision.
        """
        self.model = model
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            # One forward pass through the PyTorch model before it is swapped
            # out. NOTE(review): presumably this initialises state that the
            # decoder (model.head.decode_outputs in main) relies on — confirm
            # before removing.
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            if self.fp16:
                # main() casts the model to half under --fp16; a float32 dummy
                # input would then fail with a dtype mismatch.
                x = x.half()
            with torch.no_grad():  # warm-up only, no autograd graph needed
                self.model(x)
            self.model = model_trt
        # Normalisation constants passed to preproc (the standard ImageNet
        # per-channel mean/std values).
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img, timer):
        """Detect objects in one image.

        Args:
            img: path to an image file, or an already-decoded image array
                (e.g. a frame from ``cv2.VideoCapture.read``).
            timer: timer whose ``tic()`` is called right before the forward
                pass; the caller is responsible for the matching ``toc()``.

        Returns:
            tuple: ``(outputs, img_info)``. ``outputs`` is whatever
            ``postprocess`` returns (callers index it per batch item, and
            entry 0 may be ``None`` when nothing is detected — TODO confirm
            against yolox.utils.postprocess). ``img_info`` holds ``id``,
            ``file_name``, ``height``, ``width``, ``raw_img`` and the resize
            ``ratio`` returned by ``preproc``.

        Raises:
            ValueError: if ``img`` is a path that OpenCV cannot read.
        """
        img_info = {"id": 0}
        if isinstance(img, str):
            img_path = img
            img_info["file_name"] = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports failure by returning None rather than
                # raising; fail here with a clear message instead of an
                # AttributeError on `.shape` below.
                raise ValueError("could not read image: {}".format(img_path))
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0)  # add a leading batch dimension
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            timer.tic()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre, self.nmsthre
            )
        return outputs, img_info
def image_demo(predictor, vis_folder, path, current_time, save_result):
    """Track objects across a folder of images (or a single image file).

    Files are processed in sorted path order as consecutive frames of one
    sequence. The tracker configuration and ``min_box_area`` come from the
    module-level ``args`` defined in the ``__main__`` block.

    Args:
        predictor: ``Predictor`` used for per-frame detection.
        vis_folder: not used by this demo; kept for parity with
            ``imageflow_demo``.
        path: directory to scan recursively for images, or a single file.
        current_time: not used by this demo; kept for parity.
        save_result: when true, each frame's visualisation is written to
            ``out.jpg`` in the working directory (overwritten every frame).
    """
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    results = []  # per-frame tracks, for an optional write_results() call

    for frame_id, image_name in enumerate(files):
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        outputs, img_info = predictor.inference(image_name, timer)

        online_tlwhs = []
        online_ids = []
        online_scores = []
        if outputs[0] is not None:
            # predictor.test_size is the exp.test_size the model was set up
            # with, so the global `exp` is not needed here.
            online_targets = tracker.update(
                outputs[0], [img_info['height'], img_info['width']], predictor.test_size
            )
            for t in online_targets:
                tlwh = t.tlwh
                # Drop tiny boxes and boxes more than 1.6x wider than tall.
                too_wide = tlwh[2] / tlwh[3] > 1.6
                if tlwh[2] * tlwh[3] > args.min_box_area and not too_wide:
                    online_tlwhs.append(tlwh)
                    online_ids.append(t.track_id)
                    online_scores.append(t.score)
            timer.toc()
            online_im = plot_tracking(
                img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                fps=1. / max(1e-5, timer.average_time),
            )
        else:
            # No detections in this frame: passing None to the tracker would
            # crash, so show the raw frame instead.
            timer.toc()
            online_im = img_info['raw_img']
        results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))

        if save_result:
            cv2.imwrite("out.jpg", online_im)
        ch = cv2.waitKey(0)
        if ch in (27, ord("q"), ord("Q")):  # Esc / q / Q stops the demo
            break
def imageflow_demo(predictor, vis_folder, current_time, args):
    """Track objects frame by frame in a video file or a webcam stream.

    Args:
        predictor: ``Predictor`` used for per-frame detection.
        vis_folder: base folder for saved videos; a timestamped sub-folder is
            created inside it when ``args.save_result`` is set.
        current_time: ``time.struct_time`` used to name that sub-folder.
        args: parsed CLI arguments. Reads ``demo``, ``path``, ``camid``,
            ``save_result``, ``min_box_area`` and the tracker options consumed
            by ``BYTETracker``.
    """
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    vid_writer = None
    try:
        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
        fps = cap.get(cv2.CAP_PROP_FPS)
        if args.save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            if args.demo == "video":
                save_path = os.path.join(save_folder, os.path.basename(args.path))
            else:
                save_path = os.path.join(save_folder, "camera.mp4")
            logger.info(f"video save_path is {save_path}")
            # Capture devices may report 0 fps; fall back to 30 so the writer
            # gets a usable frame rate.
            vid_writer = cv2.VideoWriter(
                save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps if fps > 0 else 30,
                (int(width), int(height)),
            )

        tracker = BYTETracker(args, frame_rate=30)
        timer = Timer()
        frame_id = 0
        results = []  # per-frame tracks, for an optional write_results() call
        while True:
            if frame_id % 20 == 0:
                logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
            ret_val, frame = cap.read()
            if not ret_val:
                break  # end of stream or read failure
            outputs, img_info = predictor.inference(frame, timer)

            online_tlwhs = []
            online_ids = []
            online_scores = []
            if outputs[0] is not None:
                online_targets = tracker.update(
                    outputs[0], [img_info['height'], img_info['width']], predictor.test_size
                )
                for t in online_targets:
                    tlwh = t.tlwh
                    # Drop tiny boxes and boxes more than 1.6x wider than tall.
                    too_wide = tlwh[2] / tlwh[3] > 1.6
                    if tlwh[2] * tlwh[3] > args.min_box_area and not too_wide:
                        online_tlwhs.append(tlwh)
                        online_ids.append(t.track_id)
                        online_scores.append(t.score)
                timer.toc()
                online_im = plot_tracking(
                    img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                    fps=1. / max(1e-5, timer.average_time),
                )
            else:
                # No detections: passing None to the tracker would crash.
                timer.toc()
                online_im = img_info['raw_img']
            results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))

            if vid_writer is not None:
                vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch in (27, ord("q"), ord("Q")):  # Esc / q / Q stops the demo
                break
            frame_id += 1
    finally:
        # Release on every exit path so the output file is finalised and the
        # capture device is freed.
        cap.release()
        if vid_writer is not None:
            vid_writer.release()
def main(exp, args):
    """Build the detector from ``exp``/``args`` and run the requested demo.

    Args:
        exp: experiment object from ``get_exp``. Its ``test_conf``,
            ``nmsthre`` and ``test_size`` may be overridden by ``--conf``,
            ``--nms`` and ``--tsize``.
        args: parsed CLI arguments (see ``make_parser``). ``experiment_name``
            and ``device`` may be filled in or overridden here.

    Outputs go under ``exp.output_dir/<experiment_name>``. By default the
    checkpoint is ``best_ckpt.pth.tar`` and the TensorRT model is
    ``model_trt.pth``, both in that folder.
    """
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    # Always bind vis_folder: both demo functions take it as an argument, and
    # leaving it unset without --save_result raised UnboundLocalError below.
    vis_folder = os.path.join(file_name, "track_vis")
    if args.save_result:
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"  # TensorRT inference requires CUDA

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
    model.eval()

    if not args.trt:
        ckpt_file = args.ckpt if args.ckpt is not None else os.path.join(file_name, "best_ckpt.pth.tar")
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.fp16:
        model = model.half()  # to FP16

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        # With in-graph decoding disabled, Predictor applies the head's decoder
        # to the raw outputs instead.
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, trt_file, decoder, args.device, args.fp16)
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo in ("video", "webcam"):
        imageflow_demo(predictor, vis_folder, current_time, args)
    else:
        logger.error("Unknown demo type {!r}; expected image, video or webcam".format(args.demo))
if __name__ == "__main__":
    # Keep `args` (and `exp`) at module level: the demo functions above may
    # look them up as globals instead of receiving them as parameters.
    cli_parser = make_parser()
    args = cli_parser.parse_args()
    exp = get_exp(args.exp_file, args.name)
    main(exp, args)