# Copyright 2024 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Pytorch model utilities.""" import math from typing import Any, Sequence, Union import numpy as np import torch import torch.nn.functional as F def bilinear(x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor: """Resizes a 5D tensor using bilinear interpolation. Args: x: A 5D tensor of shape (B, T, W, H, C) where B is batch size, T is time, W is width, H is height, and C is the number of channels. resolution: The target resolution as a tuple (new_width, new_height). Returns: The resized tensor. """ b, t, h, w, c = x.size() x = x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w) x = F.interpolate(x, size=resolution, mode='bilinear', align_corners=False) b, _, h, w = x.size() x = x.reshape(b, t, c, h, w).permute(0, 1, 3, 4, 2) return x def map_coordinates_3d( feats: torch.Tensor, coordinates: torch.Tensor ) -> torch.Tensor: """Maps 3D coordinates to corresponding features using bilinear interpolation. Args: feats: A 5D tensor of features with shape (B, W, H, D, C), where B is batch size, W is width, H is height, D is depth, and C is the number of channels. coordinates: A 3D tensor of coordinates with shape (B, N, 3), where N is the number of coordinates and the last dimension represents (W, H, D) coordinates. Returns: The mapped features tensor. """ x = feats.permute(0, 4, 1, 2, 3) y = coordinates[:, :, None, None, :].float().clone() y[..., 0] = y[..., 0] + 0.5 y = 2 * (y / torch.tensor(x.shape[2:], device=y.device)) - 1 y = torch.flip(y, dims=(-1,)) out = ( F.grid_sample( x, y, mode='bilinear', align_corners=False, padding_mode='border' ) .squeeze(dim=(3, 4)) .permute(0, 2, 1) ) return out def map_coordinates_2d( feats: torch.Tensor, coordinates: torch.Tensor ) -> torch.Tensor: """Maps 2D coordinates to feature maps using bilinear interpolation. The function performs bilinear interpolation on the feature maps (`feats`) at the specified `coordinates`. The coordinates are normalized between -1 and 1 The result is a tensor of sampled features corresponding to these coordinates. Args: feats (Tensor): A 5D tensor of shape (N, T, H, W, C) representing feature maps, where N is the batch size, T is the number of frames, H and W are height and width, and C is the number of channels. coordinates (Tensor): A 5D tensor of shape (N, P, T, S, XY) representing coordinates, where N is the batch size, P is the number of points, T is the number of frames, S is the number of samples, and XY represents the 2D coordinates. Returns: Tensor: A 5D tensor of the sampled features corresponding to the given coordinates, of shape (N, P, T, S, C). """ n, t, h, w, c = feats.shape x = feats.permute(0, 1, 4, 2, 3).view(n * t, c, h, w) n, p, t, s, xy = coordinates.shape y = coordinates.permute(0, 2, 1, 3, 4).reshape(n * t, p, s, xy) y = 2 * (y / torch.tensor([h, w], device=feats.device)) - 1 y = torch.flip(y, dims=(-1,)).float() out = F.grid_sample( x, y, mode='bilinear', align_corners=False, padding_mode='zeros' ) _, c, _, _ = out.shape out = out.permute(0, 2, 3, 1).view(n, t, p, s, c).permute(0, 2, 1, 3, 4) return out def soft_argmax_heatmap_batched(softmax_val, threshold=5): """Test if two image resolutions are the same.""" b, h, w, d1, d2 = softmax_val.shape y, x = torch.meshgrid( torch.arange(d1, device=softmax_val.device), torch.arange(d2, device=softmax_val.device), indexing='ij', ) coords = torch.stack([x + 0.5, y + 0.5], dim=-1).to(softmax_val.device) softmax_val_flat = softmax_val.reshape(b, h, w, -1) argmax_pos = torch.argmax(softmax_val_flat, dim=-1) pos = coords.reshape(-1, 2)[argmax_pos] valid = ( torch.sum( torch.square( coords[None, None, None, :, :, :] - pos[:, :, :, None, None, :] ), dim=-1, keepdims=True, ) < threshold**2 ) weighted_sum = torch.sum( coords[None, None, None, :, :, :] * valid * softmax_val[:, :, :, :, :, None], dim=(3, 4), ) sum_of_weights = torch.maximum( torch.sum(valid * softmax_val[:, :, :, :, :, None], dim=(3, 4)), torch.tensor(1e-12, device=softmax_val.device), ) return weighted_sum / sum_of_weights def heatmaps_to_points( all_pairs_softmax, image_shape, threshold=5, query_points=None, ): """Convert heatmaps to points using soft argmax.""" out_points = soft_argmax_heatmap_batched(all_pairs_softmax, threshold) feature_grid_shape = all_pairs_softmax.shape[1:] # Note: out_points is now [x, y]; we need to divide by [width, height]. # image_shape[3] is width and image_shape[2] is height. out_points = convert_grid_coordinates( out_points, feature_grid_shape[3:1:-1], image_shape[3:1:-1], ) assert feature_grid_shape[1] == image_shape[1] if query_points is not None: # The [..., 0:1] is because we only care about the frame index. query_frame = convert_grid_coordinates( query_points.detach(), image_shape[1:4], feature_grid_shape[1:4], coordinate_format='tyx', )[..., 0:1] query_frame = torch.round(query_frame) frame_indices = torch.arange(image_shape[1], device=query_frame.device)[ None, None, : ] is_query_point = query_frame == frame_indices is_query_point = is_query_point[:, :, :, None] out_points = ( out_points * ~is_query_point + torch.flip(query_points[:, :, None], dims=(-1,))[..., 0:2] * is_query_point ) return out_points def is_same_res(r1, r2): """Test if two image resolutions are the same.""" return all([x == y for x, y in zip(r1, r2)]) def convert_grid_coordinates( coords: torch.Tensor, input_grid_size: Sequence[int], output_grid_size: Sequence[int], coordinate_format: str = 'xy', ) -> torch.Tensor: """Convert grid coordinates to correct format.""" if isinstance(input_grid_size, tuple): input_grid_size = torch.tensor(input_grid_size, device=coords.device) if isinstance(output_grid_size, tuple): output_grid_size = torch.tensor(output_grid_size, device=coords.device) if coordinate_format == 'xy': if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: raise ValueError( 'If coordinate_format is xy, the shapes must be length 2.' ) elif coordinate_format == 'tyx': if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: raise ValueError( 'If coordinate_format is tyx, the shapes must be length 3.' ) if input_grid_size[0] != output_grid_size[0]: raise ValueError('converting frame count is not supported.') else: raise ValueError('Recognized coordinate formats are xy and tyx.') position_in_grid = coords position_in_grid = position_in_grid * output_grid_size / input_grid_size return position_in_grid def generate_default_resolutions(full_size, train_size, num_levels=None): """Generate a list of logarithmically-spaced resolutions. Generated resolutions are between train_size and full_size, inclusive, with num_levels different resolutions total. Useful for generating the input to refinement_resolutions in PIPs. Args: full_size: 2-tuple of ints. The full image size desired. train_size: 2-tuple of ints. The smallest refinement level. Should typically match the training resolution, which is (256, 256) for TAPIR. num_levels: number of levels. Typically each resolution should be less than twice the size of prior resolutions. Returns: A list of resolutions. """ if all([x == y for x, y in zip(train_size, full_size)]): return [train_size] if num_levels is None: size_ratio = np.array(full_size) / np.array(train_size) num_levels = int(np.ceil(np.max(np.log2(size_ratio))) + 1) if num_levels <= 1: return [train_size] h, w = full_size[0:2] if h % 8 != 0 or w % 8 != 0: print( 'Warning: output size is not a multiple of 8. Final layer ' + 'will round size down.' ) ll_h, ll_w = train_size[0:2] sizes = [] for i in range(num_levels): size = ( int(round((ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8)) * 8, int(round((ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8)) * 8, ) sizes.append(size) return sizes class Conv2dSamePadding(torch.nn.Conv2d): def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: ih, iw = x.size()[-2:] pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) if pad_h > 0 or pad_w > 0: x = F.pad( x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] ) return F.conv2d( x, self.weight, self.bias, self.stride, # self.padding, 0, self.dilation, self.groups, )