"""Pytorch model utilities.""" |
|
import math |
|
from typing import Any, Sequence, Union |
|
from einshape.src import abstract_ops |
|
from einshape.src import backend |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |


def bilinear(x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
  """Resizes a 5D tensor using bilinear interpolation.

  Args:
    x: A 5D tensor of shape (B, T, H, W, C), where B is batch size, T is
      time, H is height, W is width, and C is the number of channels.
    resolution: The target resolution as a tuple (new_height, new_width).

  Returns:
    The resized tensor.
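
  Example (an illustrative sketch; shapes are arbitrary):
    >>> x = torch.randn(2, 8, 32, 32, 16)  # (B, T, H, W, C)
    >>> bilinear(x, (64, 48)).shape
    torch.Size([2, 8, 64, 48, 16])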
  """
  b, t, h, w, c = x.size()
  # Fold time into the channel axis so one 2D resize handles every frame.
  x = x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w)
  x = F.interpolate(x, size=resolution, mode='bilinear', align_corners=False)
  b, _, h, w = x.size()
  x = x.reshape(b, t, c, h, w).permute(0, 1, 3, 4, 2)
  return x


def map_coordinates_3d(
    feats: torch.Tensor, coordinates: torch.Tensor
) -> torch.Tensor:
  """Maps 3D coordinates to features using trilinear interpolation.

  Args:
    feats: A 5D tensor of features with shape (B, W, H, D, C), where B is
      batch size, W is width, H is height, D is depth, and C is the number of
      channels.
    coordinates: A 3D tensor of coordinates with shape (B, N, 3), where N is
      the number of coordinates and the last dimension represents (W, H, D)
      coordinates.

  Returns:
    The mapped features, a tensor of shape (B, N, C).
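
  Example (an illustrative sketch; shapes are arbitrary):
    >>> feats = torch.randn(1, 4, 4, 4, 8)
    >>> coords = torch.rand(1, 5, 3) * 4
    >>> map_coordinates_3d(feats, coords).shape
    torch.Size([1, 5, 8])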
  """
  # Channels-first layout for grid_sample.
  x = feats.permute(0, 4, 1, 2, 3)
  y = coordinates[:, :, None, None, :].float().clone()
  # Half-pixel offset on the first coordinate only.
  y[..., 0] = y[..., 0] + 0.5
  # Normalize to the [-1, 1] range expected by grid_sample.
  y = 2 * (y / torch.tensor(x.shape[2:], device=y.device)) - 1
  # grid_sample expects the coordinate order reversed (innermost axis first).
  y = torch.flip(y, dims=(-1,))
  out = (
      F.grid_sample(
          x, y, mode='bilinear', align_corners=False, padding_mode='border'
      )
      .squeeze(dim=(3, 4))
      .permute(0, 2, 1)
  )
  return out


def map_coordinates_2d(
    feats: torch.Tensor, coordinates: torch.Tensor
) -> torch.Tensor:
  """Maps 2D coordinates to feature maps using bilinear interpolation.

  The function performs bilinear interpolation on the feature maps (`feats`)
  at the specified `coordinates`, which are normalized to [-1, 1] before
  sampling. The result is a tensor of sampled features corresponding to these
  coordinates.

  Args:
    feats: A 5D tensor of shape (N, T, H, W, C) representing feature maps,
      where N is the batch size, T is the number of frames, H and W are the
      height and width, and C is the number of channels.
    coordinates: A 5D tensor of shape (N, P, T, S, XY) representing
      coordinates, where N is the batch size, P is the number of points, T is
      the number of frames, S is the number of samples, and XY represents the
      2D coordinates.

  Returns:
    A 5D tensor of the sampled features corresponding to the given
    coordinates, of shape (N, P, T, S, C).
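
  Example (an illustrative sketch; shapes are arbitrary):
    >>> feats = torch.randn(2, 8, 32, 32, 64)
    >>> coords = torch.rand(2, 100, 8, 4, 2) * 32
    >>> map_coordinates_2d(feats, coords).shape
    torch.Size([2, 100, 8, 4, 64])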
  """
  n, t, h, w, c = feats.shape
  x = feats.permute(0, 1, 4, 2, 3).view(n * t, c, h, w)

  n, p, t, s, xy = coordinates.shape
  y = coordinates.permute(0, 2, 1, 3, 4).reshape(n * t, p, s, xy)
  # Note: both coordinate components are normalized by `h`, which assumes
  # square feature maps (h == w).
  y = 2 * (y / h) - 1
  y = torch.flip(y, dims=(-1,)).float()

  out = F.grid_sample(
      x, y, mode='bilinear', align_corners=False, padding_mode='zeros'
  )
  _, c, _, _ = out.shape
  out = out.permute(0, 2, 3, 1).view(n, t, p, s, c).permute(0, 2, 1, 3, 4)

  return out


def soft_argmax_heatmap_batched(softmax_val, threshold=5):
  """Computes a windowed soft argmax for a batch of heatmaps.

  The hard argmax of each (d1, d2) heatmap is located first; the result is
  the softmax-weighted average of the pixel-center coordinates that lie
  within `threshold` pixels of that argmax.

  Args:
    softmax_val: Post-softmax heatmaps of shape (B, H, W, d1, d2).
    threshold: Radius, in pixels, of the window around the argmax.

  Returns:
    Soft-argmax positions of shape (B, H, W, 2), in (x, y) order.
  """
  b, h, w, d1, d2 = softmax_val.shape
  y, x = torch.meshgrid(
      torch.arange(d1, device=softmax_val.device),
      torch.arange(d2, device=softmax_val.device),
      indexing='ij',
  )
  # Pixel-center coordinates in (x, y) order.
  coords = torch.stack([x + 0.5, y + 0.5], dim=-1).to(softmax_val.device)
  softmax_val_flat = softmax_val.reshape(b, h, w, -1)
  argmax_pos = torch.argmax(softmax_val_flat, dim=-1)

  pos = coords.reshape(-1, 2)[argmax_pos]
  # Mask of grid positions within `threshold` pixels of the argmax.
  valid = (
      torch.sum(
          torch.square(
              coords[None, None, None, :, :, :] - pos[:, :, :, None, None, :]
          ),
          dim=-1,
          keepdim=True,
      )
      < threshold**2
  )

  weighted_sum = torch.sum(
      coords[None, None, None, :, :, :]
      * valid
      * softmax_val[:, :, :, :, :, None],
      dim=(3, 4),
  )
  sum_of_weights = torch.maximum(
      torch.sum(valid * softmax_val[:, :, :, :, :, None], dim=(3, 4)),
      torch.tensor(1e-12, device=softmax_val.device),
  )
  return weighted_sum / sum_of_weights


def heatmaps_to_points(
    all_pairs_softmax,
    image_shape,
    threshold=5,
    query_points=None,
):
  """Converts heatmaps to (x, y) points using a windowed soft argmax.

  Args:
    all_pairs_softmax: Post-softmax heatmaps of shape (B, N, T, d1, d2).
    image_shape: The (B, T, H, W) shape of the original video.
    threshold: Window radius for the soft argmax, in feature-grid pixels.
    query_points: Optional (B, N, 3) queries in (t, y, x) format. Where a
      track's frame matches its query frame, the output is replaced by the
      query location.

  Returns:
    Points of shape (B, N, T, 2), in (x, y) format and image coordinates.
  """
  out_points = soft_argmax_heatmap_batched(all_pairs_softmax, threshold)
  feature_grid_shape = all_pairs_softmax.shape[1:]

  # Scale from feature-grid coordinates to image coordinates (xy order).
  out_points = convert_grid_coordinates(
      out_points,
      feature_grid_shape[3:1:-1],
      image_shape[3:1:-1],
  )
  assert feature_grid_shape[1] == image_shape[1]
  if query_points is not None:
    # The [..., 0:1] slice keeps only the frame index of each (t, y, x) query.
    query_frame = convert_grid_coordinates(
        query_points.detach(),
        image_shape[1:4],
        feature_grid_shape[1:4],
        coordinate_format='tyx',
    )[..., 0:1]

    query_frame = torch.round(query_frame)
    frame_indices = torch.arange(image_shape[1], device=query_frame.device)[
        None, None, :
    ]
    is_query_point = query_frame == frame_indices

    # On the query frame itself, output the query location verbatim.
    is_query_point = is_query_point[:, :, :, None]
    out_points = (
        out_points * ~is_query_point
        + torch.flip(query_points[:, :, None], dims=(-1,))[..., 0:2]
        * is_query_point
    )

  return out_points


def is_same_res(r1, r2):
  """Test if two image resolutions are the same."""
  return all(x == y for x, y in zip(r1, r2))


def convert_grid_coordinates(
    coords: torch.Tensor,
    input_grid_size: Sequence[int],
    output_grid_size: Sequence[int],
    coordinate_format: str = 'xy',
) -> torch.Tensor:
  """Rescales coordinates from one grid resolution to another.

  Args:
    coords: Coordinates to convert, with the last axis in `coordinate_format`
      order.
    input_grid_size: Size of the grid that `coords` currently live in.
    output_grid_size: Size of the grid to convert `coords` into.
    coordinate_format: Either 'xy' (2D) or 'tyx' (time plus 2D); converting
      the frame count is not supported.

  Returns:
    `coords`, rescaled to `output_grid_size`.
  """
  if isinstance(input_grid_size, tuple):
    input_grid_size = torch.tensor(input_grid_size, device=coords.device)
  if isinstance(output_grid_size, tuple):
    output_grid_size = torch.tensor(output_grid_size, device=coords.device)

  if coordinate_format == 'xy':
    if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
      raise ValueError(
          'If coordinate_format is xy, the shapes must be length 2.'
      )
  elif coordinate_format == 'tyx':
    if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
      raise ValueError(
          'If coordinate_format is tyx, the shapes must be length 3.'
      )
    if input_grid_size[0] != output_grid_size[0]:
      raise ValueError('Converting frame count is not supported.')
  else:
    raise ValueError('Recognized coordinate formats are xy and tyx.')

  return coords * output_grid_size / input_grid_size


class _TorchBackend(backend.Backend[torch.Tensor]):
  """Einshape backend implementation for PyTorch."""

  def reshape(self, x: torch.Tensor, op: abstract_ops.Reshape) -> torch.Tensor:
    return x.reshape(op.shape)

  def transpose(
      self, x: torch.Tensor, op: abstract_ops.Transpose
  ) -> torch.Tensor:
    return x.permute(op.perm)

  def broadcast(
      self, x: torch.Tensor, op: abstract_ops.Broadcast
  ) -> torch.Tensor:
    shape = op.transform_shape(x.shape)
    # Insert the new axes one at a time, then expand them to the target sizes.
    for axis_position in sorted(op.axis_sizes.keys()):
      x = x.unsqueeze(axis_position)
    return x.expand(shape)


def einshape(
    equation: str, value: Union[torch.Tensor, Any], **index_sizes: int
) -> torch.Tensor:
  """Reshapes `value` according to the given Shape Equation.

  Args:
    equation: The Shape Equation specifying the index regrouping and
      reordering.
    value: Input tensor, or tensor-like object.
    **index_sizes: Sizes of indices, where they cannot be inferred from the
      shape of `value`.

  Returns:
    Tensor derived from `value` by reshaping as specified by `equation`.
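
  Example (an illustrative sketch, assuming einshape's grouping syntax):
    >>> x = torch.arange(6).reshape(2, 3)
    >>> einshape('ij->(ij)', x).tolist()
    [0, 1, 2, 3, 4, 5]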
  """
  if not isinstance(value, torch.Tensor):
    value = torch.tensor(value)
  return _TorchBackend().exec(equation, value, value.shape, **index_sizes)


def generate_default_resolutions(full_size, train_size, num_levels=None):
  """Generates a list of logarithmically-spaced resolutions.

  Generated resolutions are between train_size and full_size, inclusive, with
  num_levels different resolutions in total. Useful for generating the input
  to refinement_resolutions in PIPs.

  Args:
    full_size: 2-tuple of ints. The full image size desired.
    train_size: 2-tuple of ints. The smallest refinement level. Should
      typically match the training resolution, which is (256, 256) for TAPIR.
    num_levels: The number of levels. If None, it is chosen so that each
      resolution is less than twice the size of the previous one.

  Returns:
    A list of resolutions, each a 2-tuple with both dimensions rounded down
    to a multiple of 8.
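
  Example (an illustrative sketch, doubling from the TAPIR training size):
    >>> generate_default_resolutions((512, 512), (256, 256))
    [(256, 256), (512, 512)]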
  """
  if all(x == y for x, y in zip(train_size, full_size)):
    return [train_size]

  if num_levels is None:
    size_ratio = np.array(full_size) / np.array(train_size)
    num_levels = int(np.ceil(np.max(np.log2(size_ratio))) + 1)

  if num_levels <= 1:
    return [train_size]

  h, w = full_size[0:2]
  if h % 8 != 0 or w % 8 != 0:
    print(
        'Warning: output size is not a multiple of 8. Final layer '
        + 'will round size down.'
    )
  ll_h, ll_w = train_size[0:2]

  sizes = []
  for i in range(num_levels):
    # Interpolate geometrically between train_size and full_size, snapping
    # each dimension down to a multiple of 8.
    size = (
        int(round((ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8)) * 8,
        int(round((ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8)) * 8,
    )
    sizes.append(size)
  return sizes


class Conv2dSamePadding(torch.nn.Conv2d):
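  """Conv2d with TensorFlow-style 'SAME' padding.

  Example (an illustrative sketch; with stride 2, the output spatial size is
  ceil(input_size / stride)):
    >>> conv = Conv2dSamePadding(3, 8, kernel_size=3, stride=2)
    >>> conv(torch.randn(1, 3, 65, 65)).shape
    torch.Size([1, 8, 33, 33])
  """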

  def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
    # Total padding needed along one axis so the output size is ceil(i / s).
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    ih, iw = x.size()[-2:]

    pad_h = self.calc_same_pad(
        i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]
    )
    pad_w = self.calc_same_pad(
        i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]
    )

    if pad_h > 0 or pad_w > 0:
      # Pad asymmetrically (extra pixel on the bottom/right), as TF does.
      x = F.pad(
          x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
      )
    return F.conv2d(
        x,
        self.weight,
        self.bias,
        self.stride,
        0,  # Padding is handled above.
        self.dilation,
        self.groups,
    )