Spaces:
Running
on
A10G
Running
on
A10G
File size: 3,856 Bytes
320e465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
import torch
# Recent PyTorch releases already expose ``torch.pi``. Only install a fallback
# when the attribute is missing, so the library's own full-precision constant is
# never overwritten process-wide with a float32-rounded value
# (3.1415927410125732), which is what the previous unconditional assignment did.
if not hasattr(torch, "pi"):
    import math

    torch.pi = math.pi
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Render an optical-flow field as an RGB image.

    The flow is divided by its largest per-pixel magnitude (taken over the
    whole input, plus machine epsilon to avoid dividing by zero) before being
    handed to ``_normalized_flow_to_image``.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W)
            depending on the input.

    Raises:
        ValueError: If ``flow`` is not of dtype torch.float, or if its shape is
            neither (2, H, W) nor (N, 2, H, W).
    """
    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    # Remember the caller's shape: it is reported in errors and decides whether
    # the batch dimension has to be stripped again on the way out.
    orig_shape = flow.shape
    was_unbatched = flow.ndim == 3
    batched = flow.unsqueeze(0) if was_unbatched else flow

    if not (batched.ndim == 4 and batched.shape[1] == 2):
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")

    # Largest Euclidean norm of any flow vector in the input.
    largest_magnitude = batched.pow(2).sum(dim=1).sqrt().max()
    scale = largest_magnitude + torch.finfo(batched.dtype).eps

    rgb = _normalized_flow_to_image(batched / scale)
    return rgb[0] if was_unbatched else rgb
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Map a batch of normalized flow fields onto colorwheel colors.

    The direction of each flow vector selects a (linearly interpolated) hue on
    the colorwheel, and its magnitude blends that hue with white: zero
    magnitude gives white, magnitude one gives the pure wheel color.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W).

    Returns:
        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """
    batch, _, height, width = normalized_flow.shape
    target_device = normalized_flow.device

    wheel = _make_colorwheel().to(target_device)  # shape [55x3]
    wheel_size = wheel.shape[0]

    # Per-pixel Euclidean norm of the flow vector.
    magnitude = normalized_flow.pow(2).sum(dim=1).sqrt()

    # Angle of the negated flow vector rescaled to [-1, 1], then mapped to a
    # fractional position along the colorwheel.
    first_component = normalized_flow[:, 0, :, :]
    second_component = normalized_flow[:, 1, :, :]
    angle = torch.atan2(-second_component, -first_component) / torch.pi
    wheel_pos = (angle + 1) / 2 * (wheel_size - 1)

    # Neighbouring wheel entries to interpolate between; the upper neighbour
    # wraps around to entry 0 when it runs off the end of the wheel.
    lower = torch.floor(wheel_pos).to(torch.long)
    upper = lower + 1
    upper.masked_fill_(upper == wheel_size, 0)
    blend = wheel_pos - lower

    output = torch.zeros((batch, 3, height, width), dtype=torch.uint8, device=target_device)
    for channel, wheel_column in enumerate(wheel.unbind(dim=1)):
        low_color = wheel_column[lower] / 255.0
        high_color = wheel_column[upper] / 255.0
        mixed = (1 - blend) * low_color + blend * high_color
        # Pull the color towards white as the magnitude shrinks.
        mixed = 1 - magnitude * (1 - mixed)
        output[:, channel, :, :] = torch.floor(255. * mixed)
    return output
@torch.no_grad()
def _make_colorwheel() -> torch.Tensor:
    """
    Build the optical-flow color wheel described in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    The wheel is made of six consecutive segments. Within each segment one RGB
    channel is held at 255 while a second channel ramps linearly (floored to
    whole numbers) either up from 0 or down from 255; the third channel stays 0.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """
    # (segment length, channel held at 255, ramping channel, ramp rises?)
    segments = (
        (15, 0, 1, True),   # RY: red -> yellow
        (6, 1, 0, False),   # YG: yellow -> green
        (4, 1, 2, True),    # GC: green -> cyan
        (11, 2, 1, False),  # CB: cyan -> blue
        (13, 2, 0, True),   # BM: blue -> magenta
        (6, 0, 2, False),   # MR: magenta -> red
    )

    total_rows = sum(length for length, _, _, _ in segments)
    wheel = torch.zeros((total_rows, 3))

    start = 0
    for length, full_channel, ramp_channel, rising in segments:
        stop = start + length
        ramp = torch.floor(255. * torch.arange(0., length) / length)
        wheel[start:stop, full_channel] = 255
        wheel[start:stop, ramp_channel] = ramp if rising else 255 - ramp
        start = stop
    return wheel
|