# ALike demo: keypoint extraction and simple frame-to-frame tracking.
import copy
import os
import cv2
import glob
import logging
import argparse
import numpy as np
from tqdm import tqdm
from alike import ALike, configs
class ImageLoader(object):
    """Sequential frame source for a webcam, a video file, or an image folder.

    The source kind is inferred from ``filepath``:
      * ``'cameraX'``   -> webcam index ``X`` (e.g. ``'camera0'``)
      * path to a file  -> video file
      * path to a folder -> all ``*.png`` / ``*.jpg`` / ``*.ppm`` images, sorted

    Raises:
        IOError: if the camera/video cannot be opened or ``filepath`` is invalid.
    """

    def __init__(self, filepath: str):
        # Default frame cap; only meaningful for camera mode, which has no
        # natural end of stream.
        self.N = 3000
        if filepath.startswith('camera'):
            camera = int(filepath[6:])
            self.cap = cv2.VideoCapture(camera)
            if not self.cap.isOpened():
                raise IOError(f"Can't open camera {camera}!")
            logging.info(f'Opened camera {camera}')
            self.mode = 'camera'
        elif os.path.exists(filepath):
            if os.path.isfile(filepath):
                self.cap = cv2.VideoCapture(filepath)
                if not self.cap.isOpened():
                    raise IOError(f"Can't open video {filepath}!")
                rate = self.cap.get(cv2.CAP_PROP_FPS)
                self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
                duration = self.N / rate
                logging.info(f'Opened video {filepath}')
                logging.info(f'Frames: {self.N}, FPS: {rate}, Duration: {duration}s')
                self.mode = 'video'
            else:
                self.images = glob.glob(os.path.join(filepath, '*.png')) + \
                              glob.glob(os.path.join(filepath, '*.jpg')) + \
                              glob.glob(os.path.join(filepath, '*.ppm'))
                self.images.sort()
                self.N = len(self.images)
                logging.info(f'Loading {self.N} images')
                self.mode = 'images'
        else:
            raise IOError('Error filepath (camerax/path of images/path of videos): ', filepath)

    def __getitem__(self, item):
        """Return frame ``item`` as a BGR ndarray; ``None`` past the end in stream modes."""
        if self.mode == 'camera' or self.mode == 'video':
            if item > self.N:
                return None
            ret, img = self.cap.read()
            if not ret:
                # BUG FIX: the original raised a plain string, which is itself a
                # TypeError under Python 3 ("exceptions must derive from
                # BaseException"); raise a proper exception instead.
                raise IOError("Can't read image from camera")
            if self.mode == 'video':
                # Keep the decoder position in sync with the requested index.
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
        elif self.mode == 'images':
            filename = self.images[item]
            img = cv2.imread(filename)
            if img is None:
                raise Exception('Error reading image %s' % filename)
        return img

    def __len__(self):
        return self.N
class SimpleTracker(object):
    """Tracks keypoints between consecutive frames with mutual-nearest-neighbour
    descriptor matching, drawing keypoints/match lines on a copy of each frame."""

    def __init__(self):
        # State from the previous frame; None until the first update().
        self.pts_prev = None   # previous keypoints, shape (N, 2)
        self.desc_prev = None  # previous descriptors, shape (N, D)

    def update(self, img, pts, desc):
        """Match ``pts``/``desc`` against the previous frame and visualize.

        Args:
            img: current frame (BGR ndarray); not modified (drawn on a copy).
            pts: keypoint coordinates, array-like of (x, y).
            desc: corresponding descriptors (rows aligned with ``pts``).

        Returns:
            (out, N_matches): annotated copy of ``img`` and the match count
            (0 on the first frame, when there is nothing to match against).
        """
        N_matches = 0
        if self.pts_prev is None:
            # First frame: nothing to match, just draw the detected keypoints.
            self.pts_prev = pts
            self.desc_prev = desc

            out = copy.deepcopy(img)
            for pt1 in pts:
                p1 = (int(round(pt1[0])), int(round(pt1[1])))
                cv2.circle(out, p1, 1, (0, 0, 255), -1, lineType=16)
        else:
            matches = self.mnn_mather(self.desc_prev, desc)
            mpts1, mpts2 = self.pts_prev[matches[:, 0]], pts[matches[:, 1]]
            N_matches = len(matches)

            out = copy.deepcopy(img)
            for pt1, pt2 in zip(mpts1, mpts2):
                p1 = (int(round(pt1[0])), int(round(pt1[1])))
                p2 = (int(round(pt2[0])), int(round(pt2[1])))
                cv2.line(out, p1, p2, (0, 255, 0), lineType=16)
                cv2.circle(out, p2, 1, (0, 0, 255), -1, lineType=16)

            self.pts_prev = pts
            self.desc_prev = desc

        return out, N_matches

    def mnn_mather(self, desc1, desc2):
        """Mutual nearest-neighbour matching with a 0.9 similarity floor.

        Returns an (M, 2) int array of (index-in-desc1, index-in-desc2) pairs.
        """
        sim = desc1 @ desc2.transpose()
        sim[sim < 0.9] = 0
        nn12 = np.argmax(sim, axis=1)
        nn21 = np.argmax(sim, axis=0)
        ids1 = np.arange(0, sim.shape[0])
        mask = (ids1 == nn21[nn12])
        # BUG FIX: when a row has no similarity >= 0.9 it is all zeros after the
        # thresholding; argmax then picks column 0 and a spurious (i, 0) pair
        # can survive the mutual check. Require strictly positive similarity.
        mask &= sim[ids1, nn12] > 0
        matches = np.stack([ids1[mask], nn12[mask]])
        return matches.transpose()
def _build_cli():
    """Assemble the demo's command-line interface (options unchanged)."""
    parser = argparse.ArgumentParser(description='ALike Demo.')
    parser.add_argument('input', type=str, default='',
                        help='Image directory or movie file or "camera0" (for webcam0).')
    parser.add_argument('--model', choices=['alike-t', 'alike-s', 'alike-n', 'alike-l'], default="alike-t",
                        help="The model configuration")
    parser.add_argument('--device', type=str, default='cuda', help="Running device (default: cuda).")
    parser.add_argument('--top_k', type=int, default=-1,
                        help='Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)')
    parser.add_argument('--scores_th', type=float, default=0.2,
                        help='Detector score threshold (default: 0.2).')
    parser.add_argument('--n_limit', type=int, default=5000,
                        help='Maximum number of keypoints to be detected (default: 5000).')
    parser.add_argument('--no_display', action='store_true',
                        help='Do not display images to screen. Useful if running remotely (default: False).')
    parser.add_argument('--no_sub_pixel', action='store_true',
                        help='Do not detect sub-pixel keypoints (default: False).')
    return parser


def main():
    """Run ALike keypoint extraction + simple tracking over the input stream."""
    args = _build_cli().parse_args()
    logging.basicConfig(level=logging.INFO)

    loader = ImageLoader(args.input)
    extractor = ALike(**configs[args.model],
                      device=args.device,
                      top_k=args.top_k,
                      scores_th=args.scores_th,
                      n_limit=args.n_limit)
    tracker = SimpleTracker()

    if not args.no_display:
        logging.info("Press 'q' to stop!")
        cv2.namedWindow(args.model)

    timings = []
    progress = tqdm(loader)
    for frame in progress:
        if frame is None:
            break

        # ALike expects RGB input; OpenCV decodes to BGR.
        pred = extractor(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
                         sub_pixel=not args.no_sub_pixel)
        kpts = pred['keypoints']
        timings.append(pred['time'])

        vis, n_matches = tracker.update(frame, kpts, pred['descriptors'])

        ave_fps = (1. / np.stack(timings)).mean()
        status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{n_matches}"
        progress.set_description(status)

        if not args.no_display:
            cv2.setWindowTitle(args.model, args.model + ': ' + status)
            cv2.imshow(args.model, vis)
            if cv2.waitKey(1) == ord('q'):
                break

    logging.info('Finished!')
    if not args.no_display:
        logging.info('Press any key to exit!')
        cv2.waitKey()


if __name__ == '__main__':
    main()