|
import sys |
|
sys.path.append('core') |
|
|
|
import argparse |
|
import os |
|
import cv2 |
|
import glob |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
|
|
from raft import RAFT |
|
from utils import flow_viz |
|
from utils.utils import InputPadder |
|
|
|
|
|
|
|
# Device used for the model and all input tensors. Falls back to CPU so the
# demo still runs on machines without CUDA (CUDA-only features such as the
# alternate correlation extension will still require a GPU).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
def load_image(imfile):
    """Read an image file into a batched float tensor on ``DEVICE``.

    The image is converted to 3-channel RGB first. Without this step,
    grayscale and palette images (2-D arrays) fail at ``permute``, and
    RGBA/CMYK images yield 4 channels instead of 3.

    Args:
        imfile: path to the image file.

    Returns:
        Tensor of shape (1, 3, H, W), dtype float32, values in [0, 255],
        located on ``DEVICE``.
    """
    # The context manager closes the file handle as soon as the pixels are read.
    with Image.open(imfile) as im:
        # np.array (not asarray) gives a writable copy, which torch.from_numpy expects.
        img = np.array(im.convert('RGB'), dtype=np.uint8)
    # HWC uint8 -> CHW float
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    # Add the batch dimension and move to the target device.
    return img[None].to(DEVICE)
|
|
|
|
|
def viz(img, flo):
    """Display an image stacked above its color-coded flow and wait for a key.

    Only the first element of each batch is shown. The window title is
    'image', and the call waits for a key press with no timeout.

    Args:
        img: batched image tensor in channel-first layout. Assumed to be
            (N, 3, H, W) RGB with values in [0, 255] -- TODO confirm with callers.
        flo: batched flow tensor in channel-first layout, spatially matching
            ``img``. Presumably (N, 2, H, W) -- TODO confirm.
    """
    # Channel-first -> channel-last numpy arrays for the first batch element.
    frame = img[0].permute(1, 2, 0).cpu().numpy()
    flow_hwc = flo[0].permute(1, 2, 0).cpu().numpy()

    # Color-code the flow field so it can be displayed as an image.
    flow_rgb = flow_viz.flow_to_image(flow_hwc)

    # Input frame on top, flow visualization underneath.
    stacked = np.concatenate([frame, flow_rgb], axis=0)

    # OpenCV displays BGR, so reverse the RGB channel order.
    # It shows float images as values in [0, 1].
    bgr = stacked[:, :, ::-1] / 255.0
    cv2.imshow('image', bgr)
    cv2.waitKey()
|
|
|
|
|
def demo(args):
    """Run RAFT on consecutive frame pairs in ``args.path`` and show each flow.

    Loads the checkpoint ``args.model`` into a RAFT model. It then collects
    every ``*.png`` and ``*.jpg`` file in ``args.path``, sorted by path. For
    each adjacent pair (frame i, frame i+1) it estimates optical flow and
    displays it with ``viz``.

    Args:
        args: parsed command-line namespace. ``model`` and ``path`` are read
            here; the whole namespace is also passed to ``RAFT(args)``.
    """
    # Wrap in DataParallel before loading so the state-dict keys line up
    # (presumably the checkpoint was saved from a DataParallel model with
    # 'module.'-prefixed keys -- TODO confirm), then unwrap.
    model = torch.nn.DataParallel(RAFT(args))
    # map_location lets a GPU-saved checkpoint load when DEVICE is not CUDA.
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    model = model.module
    model.to(DEVICE)
    model.eval()

    images = sorted(glob.glob(os.path.join(args.path, '*.png')) +
                    glob.glob(os.path.join(args.path, '*.jpg')))
    if len(images) < 2:
        # Previously this case silently did nothing, which hid a wrong --path.
        print(f"Need at least two .png/.jpg frames in {args.path!r}, "
              f"found {len(images)}.")
        return

    with torch.no_grad():
        # Every interior frame is the second image of one pair and the first
        # of the next, so load it once and carry it over.
        prev = load_image(images[0])
        for imfile in images[1:]:
            curr = load_image(imfile)

            # Pad both frames to a size the model accepts; pad() returns new
            # tensors, so `curr` stays unpadded for the next iteration.
            padder = InputPadder(prev.shape)
            image1, image2 = padder.pad(prev, curr)

            _, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)

            prev = curr
|
|
|
|
|
if __name__ == '__main__': |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--model', help="restore checkpoint") |
|
parser.add_argument('--path', help="dataset for evaluation") |
|
parser.add_argument('--small', action='store_true', help='use small model') |
|
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') |
|
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') |
|
args = parser.parse_args() |
|
|
|
demo(args) |
|
|