# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from easydict import EasyDict as edict
from einops import rearrange
from sklearn.cluster import SpectralClustering

from spatracker.blocks import (
    Lie,
    BasicEncoder,
    CorrBlock,
    EUpdateFormer,
    FusionFormer,
    pix2cam,
    cam2pix,
    edgeMat,
    VitEncoder,
    DPTEnc,
    DPT_DINOv2,
    Dinov2
)
from spatracker.feature_net import (
    LocalSoftSplat
)
from spatracker.model_utils import (
    meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA
)
from spatracker.embeddings import (
    get_2d_embedding,
    get_3d_embedding,
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
    get_3d_sincos_pos_embed_from_grid,
    Embedder_Fourier,
)
from spatracker.softsplat import softsplat

torch.manual_seed(0)
def get_points_on_a_grid(grid_size, interp_shape,
                         grid_center=(0, 0), device="cuda"):
    """Place grid_size x grid_size query points on an (H, W) image,
    keeping a small margin away from the image borders."""
    if grid_size == 1:
        return torch.tensor([interp_shape[1] / 2,
                             interp_shape[0] / 2], device=device)[
            None, None
        ]

    grid_y, grid_x = meshgrid2d(
        1, grid_size, grid_size, stack=False, norm=False, device=device
    )
    step = interp_shape[1] // 64
    if grid_center[0] != 0 or grid_center[1] != 0:
        grid_y = grid_y - grid_size / 2.0
        grid_x = grid_x - grid_size / 2.0
    grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
        interp_shape[0] - step * 2
    )
    grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
        interp_shape[1] - step * 2
    )

    grid_y = grid_y + grid_center[0]
    grid_x = grid_x + grid_center[1]
    xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
    return xy
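
# Usage sketch (added for illustration, not part of the original pipeline):
# a 4x4 grid over a 480x640 image gives a (1, 16, 2) tensor of (x, y)
# coordinates with a small border margin; the CPU device is an assumption
# made for the demo.
def _demo_get_points_on_a_grid():
    pts = get_points_on_a_grid(4, (480, 640), device="cpu")
    assert pts.shape == (1, 16, 2)
    return pts
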
def sample_pos_embed(grid_size, embed_dim, coords):
    """Sample sincos positional embeddings at the frame-0 query coordinates;
    2D coords use a bilinear lookup, 3D coords use the 3D grid embedding."""
    if coords.shape[-1] == 2:
        pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim,
                                            grid_size=grid_size)
        pos_embed = (
            torch.from_numpy(pos_embed)
            .reshape(grid_size[0], grid_size[1], embed_dim)
            .float()
            .unsqueeze(0)
            .to(coords.device)
        )
        sampled_pos_embed = bilinear_sample2d(
            pos_embed.permute(0, 3, 1, 2),
            coords[:, 0, :, 0], coords[:, 0, :, 1]
        )
    elif coords.shape[-1] == 3:
        sampled_pos_embed = get_3d_sincos_pos_embed_from_grid(
            embed_dim, coords[:, :1, ...]
        ).float()[:, 0, ...].permute(0, 2, 1)
    return sampled_pos_embed
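
# Usage sketch (added for illustration). With 2D coords of shape
# (B, S, N, 2) only frame 0 is read; the (B, embed_dim, N) output layout is
# an assumption based on bilinear_sample2d's sampling convention.
def _demo_sample_pos_embed():
    coords = torch.rand(1, 8, 16, 2) * 31  # (B, S, N, 2) on a 32x32 grid
    emb = sample_pos_embed(grid_size=(32, 32), embed_dim=128, coords=coords)
    return emb  # expected: (1, 128, 16)
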
class FeatureExtractor(nn.Module):
    def __init__(
        self,
        S=8,
        stride=8,
        add_space_attn=True,
        num_heads=8,
        hidden_size=384,
        space_depth=12,
        time_depth=12,
        depth_extend_margin=0.2,
        args=edict({})
    ):
        super(FeatureExtractor, self).__init__()
        # Step 1: configure the model architecture
        self.args = args
        # Step 1.1: fill in defaults for missing args
        if getattr(args, "depth_color", None) is None:
            self.args.depth_color = False
        if getattr(args, "if_ARAP", None) is None:
            self.args.if_ARAP = True
        if getattr(args, "flash_attn", None) is None:
            self.args.flash_attn = True
        if getattr(args, "backbone", None) is None:
            self.args.backbone = "CNN"
        if getattr(args, "Nblock", None) is None:
            self.args.Nblock = 0
        if getattr(args, "Embed3D", None) is None:
            self.args.Embed3D = True
        # Step 1.2: set the model parameters
        self.S = S
        self.stride = stride
        self.hidden_dim = 256
        self.latent_dim = latent_dim = 128
        self.b_latent_dim = self.latent_dim // 3
        self.corr_levels = 4
        self.corr_radius = 3
        self.add_space_attn = add_space_attn
        self.lie = Lie()
        self.depth_extend_margin = depth_extend_margin

        # Step 2: configure the model components
        # @Encoder
        self.fnet = BasicEncoder(
            input_dim=3, output_dim=self.latent_dim, norm_fn="instance",
            dropout=0, stride=stride, Embed3D=False
        )
        # conv heads for the tri-plane features
        self.headyz = nn.Sequential(
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
        self.headxz = nn.Sequential(
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
        # @UpdateFormer
        self.updateformer = EUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=456,
            hidden_size=hidden_size,
            num_heads=num_heads,
            output_dim=latent_dim + 3,
            mlp_ratio=4.0,
            add_space_attn=add_space_attn,
            flash=getattr(self.args, "flash_attn", True)
        )
        # NOTE: allocated directly on CUDA; constructing the module assumes a GPU
        self.support_features = torch.zeros(100, 384).to("cuda") + 0.1
        self.norm = nn.GroupNorm(1, self.latent_dim)
        self.ffeat_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.ffeatyz_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.ffeatxz_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        # TODO @NeuralArap: optimize the arap
        self.embed_traj = Embedder_Fourier(
            input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True
        )
        self.embed3d = Embedder_Fourier(
            input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True
        )
        # 63 = 3 + 3 * 2 * 10 channels from the Fourier 3D embedding
        self.embedConv = nn.Conv2d(self.latent_dim + 63,
                                   self.latent_dim, 3, padding=1)
        # @Vis_predictor
        self.vis_predictor = nn.Sequential(
            nn.Linear(128, 1),
        )
        self.embedProj = nn.Linear(63, 456)
        self.zeroMLPflow = nn.Linear(195, 130)
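
    # Construction sketch (illustrative, added comment): __init__ allocates
    # `support_features` directly on CUDA, so building the module assumes a
    # GPU is available, e.g.:
    #   model = FeatureExtractor(S=8, stride=8)
    #   model = model.to("cuda")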
    def prepare_track(self, rgbds, queries):
        """
        Normalize the RGB channels and sort the queries by their first
        appearance time.

        Args:
            rgbds: the input RGB-D images (B T 4 H W)
            queries: the input queries (B N 4) as (t, x, y, d)
        Returns a tuple of:
            rgbds: the normalized RGB-D images (B T 4 H W)
            first_positive_inds / first_positive_sorted_inds: each point's
                first visible frame, in original and sorted order
            sort_inds / inv_sort_inds: the sorting permutation and its inverse
            track_mask: (B T N 1) mask, True from first appearance onwards
            gridxy: regular 2D grid at feature-map resolution (2 H' W')
            coords_init: initial point coordinates (B S N 3), sorted
            vis_init: initial visibility logits (B S N 1)
            Traj_series: initial trajectory embedding input (B T N 5), sorted
        """
        assert (rgbds.shape[2] == 4) and (queries.shape[2] == 4)
        # Step 1: normalize the RGB channels to [-1, 1]
        device = rgbds.device
        rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0
        B, T, C, H, W = rgbds.shape
        B, N, __ = queries.shape
        self.traj_e = torch.zeros((B, T, N, 3), device=device)
        self.vis_e = torch.zeros((B, T, N), device=device)
        # Step 2: sort the points by their first appearance time
        first_positive_inds = queries[0, :, 0].long()
        __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False)
        inv_sort_inds = torch.argsort(sort_inds, dim=0)
        first_positive_sorted_inds = first_positive_inds[sort_inds]
        # sanity check: the sort must be invertible
        assert torch.allclose(
            first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds]
        )
        # mask out the timesteps before each point first appears in 1..T
        ind_array = torch.arange(T, device=device)
        ind_array = ind_array[None, :, None].repeat(B, 1, N)
        track_mask = (ind_array >=
                      first_positive_inds[None, None, :]).unsqueeze(-1)
        # scale coords_init to feature-map resolution
        coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat(
            1, self.S, 1, 1
        )
        coords_init[..., :2] /= float(self.stride)
        # Step 3: initialize the regular grid
        gridx = torch.linspace(0, W // self.stride - 1, W // self.stride)
        gridy = torch.linspace(0, H // self.stride - 1, H // self.stride)
        gridx, gridy = torch.meshgrid(gridx, gridy, indexing="ij")
        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
            2, 1, 0
        )
        vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
        # Step 4: initialize the trajectories for the neural ARAP term
        T_series = torch.linspace(0, 5, T, device=device).reshape(1, T, 1, 1)  # 1 T 1 1
        T_series = T_series.repeat(B, 1, N, 1)
        # lift the queries to 3D camera coordinates
        # (self.intrs must be set by the caller before invoking this method)
        intr_init = self.intrs[:, queries[0, :, 0].long()]
        Traj_series = pix2cam(queries[:, :, None, 1:].double(), intr_init.double())  # B N 1 3
        # torch.inverse(intr_init.double()) @ queries[:, :, 1:, None].double()  # B N 3 1
        Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float()
        Traj_series = torch.cat([T_series, Traj_series], dim=-1)
        # indicator channel for the neural ARAP embedding
        Traj_mask = -1e2 * torch.ones_like(T_series)
        Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1)
        return (
            rgbds,
            first_positive_inds,
            first_positive_sorted_inds,
            sort_inds, inv_sort_inds,
            track_mask, gridxy, coords_init[..., sort_inds, :].clone(),
            vis_init, Traj_series[..., sort_inds, :].clone()
        )
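
    # Usage sketch (illustrative, added comment): the caller is expected to
    # set `self.intrs` (camera intrinsics; (B, T, 3, 3) judging from the
    # indexing above) before invoking prepare_track, roughly:
    #   model.intrs = intrs  # (B, T, 3, 3)
    #   (rgbds, first_inds, first_sorted_inds, sort_inds, inv_sort_inds,
    #    track_mask, gridxy, coords_init, vis_init, traj_init
    #    ) = model.prepare_track(rgbds, queries)
    #   # rgbds: (B, T, 4, H, W) with RGB in [0, 255]; queries: (B, N, 4)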
    def sample_trifeat(self, t,
                       coords,
                       featMapxy,
                       featMapyz,
                       featMapxz):
        """
        Sample point features from the three 5D tri-plane feature maps.

        Args:
            t: the query frame index of each point, flattened to (B*N,)
            coords: the point coordinates (B S N 3)
            featMapxy: the xy-plane feature map (B S C Hy Wx)
            featMapyz: the yz-plane feature map (B S C Hz Wy)
            featMapxz: the xz-plane feature map (B S C Hz Wx)
        Returns:
            the initial xy / yz / xz point features, each (B S N C)
        """
        # build (t, u, v) sampling coordinates for each plane
        queried_t = t.reshape(1, 1, -1, 1)
        xy_t = torch.cat(
            [queried_t, coords[..., [0, 1]]],
            dim=-1
        )
        yz_t = torch.cat(
            [queried_t, coords[..., [1, 2]]],
            dim=-1
        )
        xz_t = torch.cat(
            [queried_t, coords[..., [0, 2]]],
            dim=-1
        )
        featxy_init = sample_features5d(featMapxy, xy_t)
        featyz_init = sample_features5d(featMapyz, yz_t)
        featxz_init = sample_features5d(featMapxz, xz_t)
        # tile the sampled features across the S window frames
        featxy_init = featxy_init.repeat(1, self.S, 1, 1)
        featyz_init = featyz_init.repeat(1, self.S, 1, 1)
        featxz_init = featxz_init.repeat(1, self.S, 1, 1)
        return featxy_init, featyz_init, featxz_init
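
    # Usage sketch (illustrative, added comment; shapes follow forward() below):
    #   fxy, fyz, fxz = self.sample_trifeat(
    #       t=torch.zeros(B * N, device=coords_init.device),
    #       coords=coords_init,                    # (B, 1, N, 3)
    #       featMapxy=fmapXY,                      # (B, T, C, H', W')
    #       featMapyz=fmapYZ, featMapxz=fmapXZ)    # (B, T, C, Dz, *)
    #   # each output: (B, self.S, N, C)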
    def forward(self, rgbds, queries, num_levels=4, feat_init=None,
                is_train=False, intrs=None, wind_S=None):
        '''
        Args:
            rgbds: the input RGB-D images (B T 4 H W)
            queries: query trajectories (B f N 3) as [x, y, z]; x, y are
                image coordinates and z is depth (normalized here)
            num_levels, feat_init, is_train, intrs, wind_S: unused here
        Returns:
            the initial tri-plane point features for frame 0, stacked to
            (B S N latent_dim 3) over the xy / yz / xz planes
        '''
        B, T, C, H, W = rgbds.shape
        Dz = W // self.stride
        # flatten time into the batch dim for the 2D encoder
        rgbs_ = rgbds[:, :, :3, ...].reshape(B * T, 3, H, W)
        depth_all = rgbds[:, :, 3, ...]
        # near/far planes over the clip and the queries, with a safety margin
        d_near = self.d_near = depth_all[depth_all > 0.01].min().item()
        d_far = self.d_far = depth_all[depth_all > 0.01].max().item()
        d_near_z = queries.reshape(B, -1, 3)[:, :, 2].min().item()
        d_far_z = queries.reshape(B, -1, 3)[:, :, 2].max().item()
        d_near = min(d_near, d_near_z)
        d_far = max(d_far, d_far_z)
        d_near = min(d_near - self.depth_extend_margin, 0.01)
        d_far = d_far + self.depth_extend_margin
        depths = (depth_all - d_near) / (d_far - d_near)
        depths_dn = nn.functional.interpolate(
            depths.reshape(B * T, 1, H, W),
            scale_factor=1.0 / self.stride, mode="nearest")
        depths_dnG = depths_dn * Dz
        # initialize the regular grid at feature-map resolution
        gridx = torch.linspace(0, W // self.stride - 1, W // self.stride)
        gridy = torch.linspace(0, H // self.stride - 1, H // self.stride)
        gridx, gridy = torch.meshgrid(gridx, gridy, indexing="ij")
        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
            2, 1, 0
        )  # 2 H W
        gridxyz = torch.cat([gridxy[None, ...].repeat(
            depths_dn.shape[0], 1, 1, 1), depths_dnG], dim=1)
        # splatting flows from the xy plane to the yz / xz planes
        Fxy2yz = gridxyz[:, [1, 2], ...] - gridxyz[:, :2]
        Fxy2xz = gridxyz[:, [0, 2], ...] - gridxyz[:, :2]
        # fmaps_ caches encoder features across sliding windows in the full
        # tracker; a single forward pass starts empty
        fmaps_ = None
        if getattr(self.args, "Embed3D", False):
            # normalize xyz to [-1, 1] before the Fourier embedding
            gridxyz_nm = gridxyz.clone()
            for i in range(3):
                ch = gridxyz_nm[:, i, ...]
                gridxyz_nm[:, i, ...] = (ch - ch.min()) / (ch.max() - ch.min())
            gridxyz_nm = 2 * (gridxyz_nm - 0.5)
            _, _, h4, w4 = gridxyz_nm.shape
            gridxyz_nm = gridxyz_nm.permute(0, 2, 3, 1).reshape(B * T * h4 * w4, 3)
            featPE = self.embed3d(gridxyz_nm).view(B * T, h4, w4, -1).permute(0, 3, 1, 2)
            if fmaps_ is None:
                fmaps_ = torch.cat([self.fnet(rgbs_), featPE], dim=1)
                fmaps_ = self.embedConv(fmaps_)
            else:
                # sliding-window update: recompute only the second half
                fmaps_new = torch.cat(
                    [self.fnet(rgbs_[self.S // 2:]), featPE[self.S // 2:]], dim=1)
                fmaps_new = self.embedConv(fmaps_new)
                fmaps_ = torch.cat(
                    [fmaps_[self.S // 2:], fmaps_new], dim=0
                )
        else:
            if fmaps_ is None:
                fmaps_ = self.fnet(rgbs_)
            else:
                fmaps_ = torch.cat(
                    [fmaps_[self.S // 2:], self.fnet(rgbs_[self.S // 2:])], dim=0
                )
        fmapXY = fmaps_[:, :self.latent_dim].reshape(
            B, T, self.latent_dim, H // self.stride, W // self.stride
        )
        # splat the xy features onto the yz and xz planes (assumes B == 1)
        fmapYZ = softsplat(fmapXY[0], Fxy2yz, None,
                           strMode="avg", tenoutH=Dz, tenoutW=H // self.stride)
        fmapXZ = softsplat(fmapXY[0], Fxy2xz, None,
                           strMode="avg", tenoutH=Dz, tenoutW=W // self.stride)
        fmapYZ = self.headyz(fmapYZ)[None, ...]
        fmapXZ = self.headxz(fmapXZ)[None, ...]
        # scale coords to feature-map resolution; clone so the caller's
        # queries are not modified in place
        coords_init = queries[:, :1].clone()  # B 1 N 3, the first frame
        coords_init[..., :2] /= float(self.stride)
        (featxy_init,
         featyz_init,
         featxz_init) = self.sample_trifeat(
            t=torch.zeros(B * queries.shape[2], device=queries.device),
            featMapxy=fmapXY,
            featMapyz=fmapYZ, featMapxz=fmapXZ,
            coords=coords_init  # B 1 N 3
        )
        return torch.stack([featxy_init, featyz_init, featxz_init], dim=-1)
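

# End-to-end sketch (added for illustration; the sizes, value ranges and the
# CUDA device below are assumptions, not part of the original file). It runs
# forward() on a random RGB-D clip; the output stacks the xy / yz / xz point
# features on the last dimension.
def _demo_forward():
    model = FeatureExtractor(S=8, stride=8).to("cuda").eval()
    B, T, N, H, W = 1, 8, 16, 256, 256
    rgbds = torch.rand(B, T, 4, H, W, device="cuda")
    rgbds[:, :, :3] *= 255.0   # RGB in [0, 255]
    rgbds[:, :, 3] += 0.5      # strictly positive depth
    queries = torch.rand(B, 1, N, 3, device="cuda")  # (B, f, N, 3)
    queries[..., 0] *= W - 1.0
    queries[..., 1] *= H - 1.0
    queries[..., 2] += 0.5
    with torch.no_grad():
        feats = model(rgbds, queries)
    assert feats.shape == (B, model.S, N, model.latent_dim, 3)
    return feats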