|
import copy |
|
import os |
|
import cv2 |
|
import glob |
|
import logging |
|
import argparse |
|
import numpy as np |
|
from tqdm import tqdm |
|
from alike import ALike, configs |
|
|
|
|
|
class ImageLoader(object):
    """Sequential frame source over a webcam, a video file, or an image directory.

    The source is selected from ``filepath``:
      * ``"cameraX"``   -> webcam index ``X`` (e.g. ``"camera0"``)
      * path to a file  -> video file
      * path to a dir   -> all ``*.png`` / ``*.jpg`` / ``*.ppm`` images, sorted

    Supports ``len()`` and integer indexing, so it can be iterated directly
    (the demo loop below relies on the sequence protocol).
    """

    def __init__(self, filepath: str):
        # Default frame budget for endless sources (webcam mode).
        self.N = 3000
        if filepath.startswith("camera"):
            camera = int(filepath[6:])  # e.g. "camera0" -> 0
            self.cap = cv2.VideoCapture(camera)
            if not self.cap.isOpened():
                raise IOError(f"Can't open camera {camera}!")
            logging.info(f"Opened camera {camera}")
            self.mode = "camera"
        elif os.path.exists(filepath):
            if os.path.isfile(filepath):
                self.cap = cv2.VideoCapture(filepath)
                if not self.cap.isOpened():
                    raise IOError(f"Can't open video {filepath}!")
                rate = self.cap.get(cv2.CAP_PROP_FPS)
                self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
                duration = self.N / rate
                logging.info(f"Opened video {filepath}")
                logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s")
                self.mode = "video"
            else:
                self.images = (
                    glob.glob(os.path.join(filepath, "*.png"))
                    + glob.glob(os.path.join(filepath, "*.jpg"))
                    + glob.glob(os.path.join(filepath, "*.ppm"))
                )
                self.images.sort()
                self.N = len(self.images)
                logging.info(f"Loading {self.N} images")
                self.mode = "images"
        else:
            raise IOError(
                "Error filepath (camerax/path of images/path of videos): ", filepath
            )

    def __getitem__(self, item):
        """Return frame ``item`` as a BGR image, or ``None`` when exhausted.

        Raises:
            IOError: if the capture device or an image file cannot be read.
            IndexError: (images mode) when ``item`` runs past the file list,
                which naturally terminates sequence-protocol iteration.
        """
        if self.mode == "camera" or self.mode == "video":
            if item > self.N:
                # Signal end-of-stream; the demo loop breaks on None.
                return None
            if self.mode == "video":
                # Seek BEFORE reading. The original seek-after-read rewound the
                # stream, duplicating the first frame and returning every later
                # frame one step behind its index.
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
            ret, img = self.cap.read()
            if not ret:
                # Raising a bare string is a TypeError in Python 3; raise a
                # proper exception instance instead.
                raise IOError("Can't read image from camera")
        elif self.mode == "images":
            filename = self.images[item]  # IndexError here ends iteration
            img = cv2.imread(filename)
            if img is None:
                raise IOError("Error reading image %s" % filename)
        return img

    def __len__(self):
        # Number of frames/images this loader expects to yield.
        return self.N
|
|
|
|
|
class SimpleTracker(object):
    """Frame-to-frame keypoint tracker using mutual-nearest-neighbour matching.

    Stores the previous frame's keypoints/descriptors; on each ``update`` it
    matches the new frame against them and draws the tracks (green lines) and
    keypoints (red dots) onto a copy of the input image.
    """

    def __init__(self, sim_threshold: float = 0.9):
        # sim_threshold: minimum descriptor similarity for a valid match.
        # Default 0.9 preserves the previously hard-coded value.
        self.sim_threshold = sim_threshold
        self.pts_prev = None
        self.desc_prev = None

    def update(self, img, pts, desc):
        """Match (pts, desc) against the previous frame and draw the result.

        Args:
            img: BGR image to annotate (not modified; a copy is drawn on).
            pts: (N, 2) keypoint coordinates in pixels.
            desc: (N, D) descriptors, row-aligned with ``pts``.

        Returns:
            (out, N_matches): annotated image copy and the number of mutual-NN
            matches (0 on the first frame, when there is nothing to match).
        """
        N_matches = 0
        if self.pts_prev is None:
            # First frame: nothing to match yet — just draw the keypoints.
            self.pts_prev = pts
            self.desc_prev = desc

            out = copy.deepcopy(img)
            for pt1 in pts:
                p1 = (int(round(pt1[0])), int(round(pt1[1])))
                cv2.circle(out, p1, 1, (0, 0, 255), -1, lineType=16)
        else:
            matches = self.mnn_mather(self.desc_prev, desc)
            mpts1, mpts2 = self.pts_prev[matches[:, 0]], pts[matches[:, 1]]
            N_matches = len(matches)

            out = copy.deepcopy(img)
            for pt1, pt2 in zip(mpts1, mpts2):
                p1 = (int(round(pt1[0])), int(round(pt1[1])))
                p2 = (int(round(pt2[0])), int(round(pt2[1])))
                cv2.line(out, p1, p2, (0, 255, 0), lineType=16)
                cv2.circle(out, p2, 1, (0, 0, 255), -1, lineType=16)

            self.pts_prev = pts
            self.desc_prev = desc

        return out, N_matches

    def mnn_mather(self, desc1, desc2):
        """Mutual nearest-neighbour descriptor matching.

        Args:
            desc1: (N1, D) descriptor array.
            desc2: (N2, D) descriptor array.

        Returns:
            (M, 2) int array of (index-in-desc1, index-in-desc2) pairs that
            are each other's nearest neighbour.
        """
        # Guard empty inputs: np.argmax over an empty axis raises ValueError.
        if len(desc1) == 0 or len(desc2) == 0:
            return np.empty((0, 2), dtype=int)
        sim = desc1 @ desc2.transpose()
        # Suppress weak similarities (below threshold) before the NN search.
        sim[sim < self.sim_threshold] = 0
        nn12 = np.argmax(sim, axis=1)  # best desc2 index for each desc1 row
        nn21 = np.argmax(sim, axis=0)  # best desc1 index for each desc2 row
        ids1 = np.arange(0, sim.shape[0])
        mask = ids1 == nn21[nn12]  # keep only mutually-best pairs
        matches = np.stack([ids1[mask], nn12[mask]])
        return matches.transpose()
|
|
|
|
|
if __name__ == "__main__":
    # --- CLI -----------------------------------------------------------------
    parser = argparse.ArgumentParser(description="ALike Demo.")
    parser.add_argument(
        "input",
        type=str,
        default="",
        help='Image directory or movie file or "camera0" (for webcam0).',
    )
    parser.add_argument(
        "--model",
        choices=["alike-t", "alike-s", "alike-n", "alike-l"],
        default="alike-t",
        help="The model configuration",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Running device (default: cuda)."
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=-1,
        help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)",
    )
    parser.add_argument(
        "--scores_th",
        type=float,
        default=0.2,
        help="Detector score threshold (default: 0.2).",
    )
    parser.add_argument(
        "--n_limit",
        type=int,
        default=5000,
        help="Maximum number of keypoints to be detected (default: 5000).",
    )
    parser.add_argument(
        "--no_display",
        action="store_true",
        help="Do not display images to screen. Useful if running remotely (default: False).",
    )
    parser.add_argument(
        "--no_sub_pixel",
        action="store_true",
        help="Do not detect sub-pixel keypoints (default: False).",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # --- Setup ---------------------------------------------------------------
    image_loader = ImageLoader(args.input)
    model = ALike(
        **configs[args.model],
        device=args.device,
        top_k=args.top_k,
        scores_th=args.scores_th,
        n_limit=args.n_limit,
    )
    tracker = SimpleTracker()

    if not args.no_display:
        logging.info("Press 'q' to stop!")
        cv2.namedWindow(args.model)

    # --- Main loop -----------------------------------------------------------
    # Running sum/count of per-frame FPS. The previous implementation stacked
    # and re-averaged the entire runtime history every frame (O(n^2) over the
    # run); the incremental average is identical and O(1) per frame.
    fps_sum = 0.0
    n_frames = 0
    progress_bar = tqdm(image_loader)
    for img in progress_bar:
        if img is None:
            # ImageLoader signals end-of-stream with None (camera/video mode).
            break

        # Model expects RGB; OpenCV delivers BGR.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
        kpts = pred["keypoints"]
        desc = pred["descriptors"]
        fps_sum += 1.0 / pred["time"]
        n_frames += 1

        out, N_matches = tracker.update(img, kpts, desc)

        ave_fps = fps_sum / n_frames
        status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
        progress_bar.set_description(status)

        if not args.no_display:
            cv2.setWindowTitle(args.model, args.model + ": " + status)
            cv2.imshow(args.model, out)
            if cv2.waitKey(1) == ord("q"):
                break

    logging.info("Finished!")
    if not args.no_display:
        logging.info("Press any key to exit!")
        cv2.waitKey()
|
|