# ObjCtrl-2.5D: cameractrl/models/feature_extractor.py
# (author: wzhouxiff, commit 38e3f9b "init")
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from easydict import EasyDict as edict
from einops import rearrange
from sklearn.cluster import SpectralClustering
from spatracker.blocks import Lie
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
from spatracker.blocks import (
BasicEncoder,
CorrBlock,
EUpdateFormer,
FusionFormer,
pix2cam,
cam2pix,
edgeMat,
VitEncoder,
DPTEnc,
DPT_DINOv2,
Dinov2
)
from spatracker.feature_net import (
LocalSoftSplat
)
from spatracker.model_utils import (
meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA
)
from spatracker.embeddings import (
get_2d_embedding,
get_3d_embedding,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
get_3d_sincos_pos_embed_from_grid,
Embedder_Fourier,
)
import numpy as np
from spatracker.softsplat import softsplat
# Seed torch's global RNG when this module is imported (presumably so the
# random weight initialisation of the modules below is reproducible).
# NOTE(review): module-level side effect - it resets the RNG state of every
# process that imports this file; confirm that is intended.
torch.manual_seed(0)
def get_points_on_a_grid(grid_size, interp_shape,
                         grid_center=(0, 0), device="cuda"):
    """Return evenly spaced (x, y) query points covering an image.

    Args:
        grid_size: number of points per axis. ``1`` short-circuits to the
            single image-centre point.
        interp_shape: image size indexed as ``interp_shape[0]`` (rows / y)
            and ``interp_shape[1]`` (columns / x).
        grid_center: ``(y, x)`` offset added to every point. When it is not
            ``(0, 0)`` the raw grid indices are first shifted by
            ``-grid_size / 2``.
        device: device of the returned tensor.

    Returns:
        Tensor whose last dimension holds ``(x, y)``; shape ``(1, 1, 2)``
        for ``grid_size == 1``, otherwise ``(1, P, 2)`` with ``P`` the
        number of grid points produced by ``meshgrid2d``.
    """
    height, width = interp_shape[0], interp_shape[1]
    if grid_size == 1:
        centre = torch.tensor([width / 2, height / 2], device=device)
        return centre[None, None]

    rows, cols = meshgrid2d(
        1, grid_size, grid_size, stack=False, norm=False, device=device
    )
    # Keep a border of width // 64 pixels free on every side.
    margin = width // 64
    if grid_center[0] != 0 or grid_center[1] != 0:
        half = grid_size / 2.0
        rows, cols = rows - half, cols - half
    spacing = float(grid_size - 1)
    rows = margin + rows.reshape(1, -1) / spacing * (height - margin * 2)
    cols = margin + cols.reshape(1, -1) / spacing * (width - margin * 2)
    rows = rows + grid_center[0]
    cols = cols + grid_center[1]
    return torch.stack([cols, rows], dim=-1).to(device)
def sample_pos_embed(grid_size, embed_dim, coords):
    """Sample fixed sin-cos positional embeddings at the given coordinates.

    Args:
        grid_size: ``(H, W)`` size of the 2D embedding grid; only used when
            ``coords`` carries 2D points.
        embed_dim: dimensionality of each embedding vector.
        coords: point coordinates whose last dimension is 2 (2D points on
            the embedding grid) or 3 (3D points). The indexing below reads
            ``coords[:, 0, :, :]``, so a 4-D layout is expected.

    Returns:
        The sampled embeddings; the exact layout depends on the helper used
        by each branch.

    Raises:
        ValueError: if the last dimension of ``coords`` is neither 2 nor 3
            (previously this surfaced as an ``UnboundLocalError``).
    """
    point_dim = coords.shape[-1]
    if point_dim == 2:
        pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim,
                                            grid_size=grid_size)
        pos_embed = (
            torch.from_numpy(pos_embed)
            .reshape(grid_size[0], grid_size[1], embed_dim)
            .float()
            .unsqueeze(0)
            .to(coords.device)
        )
        # (1, H, W, C) -> (1, C, H, W) before sampling at the (x, y) points.
        return bilinear_sample2d(
            pos_embed.permute(0, 3, 1, 2),
            coords[:, 0, :, 0], coords[:, 0, :, 1]
        )
    if point_dim == 3:
        return get_3d_sincos_pos_embed_from_grid(
            embed_dim, coords[:, :1, ...]
        ).float()[:, 0, ...].permute(0, 2, 1)
    raise ValueError(
        f"coords must have a last dimension of 2 or 3, got {point_dim}"
    )
class FeatureExtractor(nn.Module):
    """Extract tri-plane (xy / yz / xz) features for 3D query points.

    Image features come from a CNN encoder (``fnet``). The yz and xz planes
    are obtained by splatting the xy feature map along the depth axis and
    refining the result with small conv heads (``headyz`` / ``headxz``).

    Several submodules created in ``__init__`` (``updateformer``, ``norm``,
    the ``ffeat*_updater`` MLPs, ``embed_traj``, ``vis_predictor``,
    ``embedProj``, ``zeroMLPflow``, ``lie``) are not used by the methods of
    this class - presumably they are kept so checkpoints of the full tracker
    load cleanly; confirm before removing them.
    """

    def __init__(
        self,
        S=8,
        stride=8,
        add_space_attn=True,
        num_heads=8,
        hidden_size=384,
        space_depth=12,
        time_depth=12,
        depth_extend_margin=0.2,
        args=None,
    ):
        """Build the encoder, tri-plane heads and auxiliary modules.

        Args:
            S: number of times each sampled per-point feature is repeated
                along dim 1 (see ``sample_trifeat``).
            stride: down-sampling factor between the input image and the
                feature maps.
            add_space_attn, num_heads, hidden_size, space_depth, time_depth:
                forwarded to ``EUpdateFormer``.
            depth_extend_margin: margin added around the observed depth range
                before depths are normalised in ``forward``.
            args: ``EasyDict``-like options. Options that are missing or
                ``None`` are filled with defaults directly on this object
                (so the caller's object is updated in place). A fresh
                ``EasyDict`` is created when omitted.
        """
        super().__init__()
        # step1: config the arch of the model
        # A fresh dict per instance: a shared default would be mutated by the
        # defaults below and leak settings between instances.
        self.args = edict({}) if args is None else args
        # step1.1: fill in defaults for options that are missing or None
        defaults = {
            "depth_color": False,
            "if_ARAP": True,
            "flash_attn": True,
            "backbone": "CNN",
            "Nblock": 0,
            "Embed3D": True,
        }
        for key, value in defaults.items():
            if getattr(self.args, key, None) is None:
                setattr(self.args, key, value)
        # step1.2: config the model parameters
        self.S = S
        self.stride = stride
        self.hidden_dim = 256
        self.latent_dim = latent_dim = 128
        self.b_latent_dim = self.latent_dim // 3
        self.corr_levels = 4
        self.corr_radius = 3
        self.add_space_attn = add_space_attn
        self.lie = Lie()
        self.depth_extend_margin = depth_extend_margin
        # step2: config the model components
        # (creation order is kept stable: it determines the random init and
        # the state_dict key order)
        # @Encoder
        self.fnet = BasicEncoder(
            input_dim=3,
            output_dim=self.latent_dim,
            norm_fn="instance",
            dropout=0,
            stride=stride,
            Embed3D=False,
        )
        # conv heads refining the splatted yz / xz tri-plane features
        self.headyz = nn.Sequential(
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
        self.headxz = nn.Sequential(
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
        # @UpdateFormer
        self.updateformer = EUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=456,
            hidden_size=hidden_size,
            num_heads=num_heads,
            output_dim=latent_dim + 3,
            mlp_ratio=4.0,
            add_space_attn=add_space_attn,
            flash=getattr(self.args, "flash_attn", True)
        )
        # Plain tensor attribute (not a registered buffer), so it does not
        # follow ``module.to(device)``. It is placed on CUDA when available,
        # as before, but falls back to CPU instead of failing construction.
        support_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.support_features = (
            torch.zeros(100, 384, device=support_device) + 0.1
        )
        self.norm = nn.GroupNorm(1, self.latent_dim)
        self.ffeat_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.ffeatyz_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.ffeatxz_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        # TODO(NeuralArap): optimize the arap
        self.embed_traj = Embedder_Fourier(
            input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True
        )
        self.embed3d = Embedder_Fourier(
            input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True
        )
        # fuses the encoder features with the 63-channel embed3d positional
        # encoding concatenated to them in ``forward``
        self.embedConv = nn.Conv2d(self.latent_dim + 63,
                                   self.latent_dim, 3, padding=1)
        # @Vis_predictor
        self.vis_predictor = nn.Sequential(
            nn.Linear(128, 1),
        )
        self.embedProj = nn.Linear(63, 456)
        self.zeroMLPflow = nn.Linear(195, 130)

    def prepare_track(self, rgbds, queries):
        """Normalize the RGB channels and sort queries by first-appearance time.

        NOTE: ``rgbds`` is modified in place (its RGB channels are rescaled
        from [0, 255] to [-1, 1]); pass a clone to keep the original. The
        method also sets ``self.traj_e`` / ``self.vis_e`` to zero tensors and
        reads ``self.intrs``, which nothing in this class assigns - the
        caller must set it beforehand.

        Args:
            rgbds: the input rgbd images (B T 4 H W)
            queries: the input queries (B N 4) as ``(t, x, y, z)``; the
                first-appearance times are read from batch element 0.

        Returns:
            A tuple of:
            rgbds: the normalized rgbds (B T 4 H W), same tensor as the input;
            first_positive_inds: first frame of each query (N,);
            first_positive_sorted_inds: the same, sorted ascending;
            sort_inds, inv_sort_inds: the sorting permutation and its inverse;
            track_mask: (B T N 1) bool, True from each query's first frame on;
            gridxy: (2 H/stride W/stride) regular grid of feature-map coords;
            coords_init: (B S N 3) query coords with x, y divided by stride,
                in sorted order;
            vis_init: (B S N 1) filled with 10;
            Traj_series: presumably (B T N 5) = [time, camera-space xyz,
                -100 indicator], in sorted order.
        """
        assert (rgbds.shape[2] == 4) and (queries.shape[2] == 4)
        # Step1: normalize the rgbs input
        device = rgbds.device
        rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0
        B, T, C, H, W = rgbds.shape
        B, N, __ = queries.shape
        self.traj_e = torch.zeros((B, T, N, 3), device=device)
        self.vis_e = torch.zeros((B, T, N), device=device)
        # Step2: sort the points via their first appeared time
        first_positive_inds = queries[0, :, 0].long()
        __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False)
        inv_sort_inds = torch.argsort(sort_inds, dim=0)
        first_positive_sorted_inds = first_positive_inds[sort_inds]
        # sanity check: inv_sort_inds undoes sort_inds
        assert torch.allclose(
            first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds]
        )
        # mask out the frames before each point's first appearance
        ind_array = torch.arange(T, device=device)
        ind_array = ind_array[None, :, None].repeat(B, 1, N)
        track_mask = (ind_array >=
                      first_positive_inds[None, None, :]).unsqueeze(-1)
        # scale the coords_init (x, y) to feature-map resolution
        coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat(
            1, self.S, 1, 1
        )
        coords_init[..., :2] /= float(self.stride)
        # Step3: initial the regular grid
        gridx = torch.linspace(0, W // self.stride - 1, W // self.stride)
        gridy = torch.linspace(0, H // self.stride - 1, H // self.stride)
        gridx, gridy = torch.meshgrid(gridx, gridy)
        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
            2, 1, 0
        )
        vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
        # Step4: initial traj for neural arap
        # Move to the input's device rather than the default CUDA device.
        T_series = torch.linspace(0, 5, T).reshape(1, T, 1, 1).to(device)  # 1 T 1 1
        T_series = T_series.repeat(B, 1, N, 1)
        # get the 3d traj in the camera coordinates
        intr_init = self.intrs[:, queries[0, :, 0].long()]
        Traj_series = pix2cam(queries[:, :, None, 1:].double(), intr_init.double())
        Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float()
        Traj_series = torch.cat([T_series, Traj_series], dim=-1)
        # get the indicator for the neural arap
        Traj_mask = -1e2 * torch.ones_like(T_series)
        Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1)
        return (
            rgbds,
            first_positive_inds,
            first_positive_sorted_inds,
            sort_inds, inv_sort_inds,
            track_mask, gridxy, coords_init[..., sort_inds, :].clone(),
            vis_init, Traj_series[..., sort_inds, :].clone()
        )

    def sample_trifeat(self, t,
                       coords,
                       featMapxy,
                       featMapyz,
                       featMapxz):
        """Sample per-point features from the three tri-plane feature maps.

        Each 5D map (B S C H W) is sampled with ``sample_features5d`` at the
        triple ``(t, a, b)``, where ``(a, b)`` is ``(x, y)``, ``(y, z)`` and
        ``(x, z)`` respectively.

        Args:
            t: the time index of every point; it is reshaped to (1, 1, -1, 1)
                and concatenated with ``coords``, which presumably requires
                B == 1 and a single row along dim 1 - TODO confirm.
            coords: the coordinates of the points (B S N 3) as (x, y, z)
            featMapxy: xy-plane feature map (B S C H W)
            featMapyz: yz-plane feature map (B S C H W)
            featMapxz: xz-plane feature map (B S C H W)

        Returns:
            (featxy_init, featyz_init, featxz_init), each repeated
            ``self.S`` times along dim 1.
        """
        # get xy_t yz_t xz_t
        queried_t = t.reshape(1, 1, -1, 1)
        xy_t = torch.cat(
            [queried_t, coords[..., [0, 1]]],
            dim=-1
        )
        yz_t = torch.cat(
            [queried_t, coords[..., [1, 2]]],
            dim=-1
        )
        xz_t = torch.cat(
            [queried_t, coords[..., [0, 2]]],
            dim=-1
        )
        featxy_init = sample_features5d(featMapxy, xy_t)
        featyz_init = sample_features5d(featMapyz, yz_t)
        featxz_init = sample_features5d(featMapxz, xz_t)
        featxy_init = featxy_init.repeat(1, self.S, 1, 1)
        featyz_init = featyz_init.repeat(1, self.S, 1, 1)
        featxz_init = featxz_init.repeat(1, self.S, 1, 1)
        return featxy_init, featyz_init, featxz_init

    def forward(self, rgbds, queries, num_levels=4, feat_init=None,
                is_train=False, intrs=None, wind_S=None):
        """Sample tri-plane features for the first-row query points.

        Args:
            rgbds: RGB-D frames (B, T, 4, H, W). Channels 0-2 go to the
                encoder as-is (no normalisation here); channel 3 is depth.
            queries: query points (B, f, N, 3) as [x, y, z]. x and y are
                divided by ``self.stride`` to index the feature maps
                (presumably input-image pixels - TODO confirm); z is depth.
                Only ``queries[:, :1]`` is sampled. ``queries`` is not
                modified.
            num_levels, feat_init, is_train, intrs, wind_S: unused; kept for
                interface compatibility.

        Returns:
            ``torch.stack([featxy, featyz, featxz], dim=-1)`` of the features
            returned by ``sample_trifeat``.

        NOTE(review): as written this appears to assume B == T == 1. The
        splatting flow is built from the first depth channel only. The 3D
        embedding reshapes to exactly 3 channels. ``sample_trifeat`` needs
        B == 1. Confirm this with the callers.
        """
        B, T, C, H, W = rgbds.shape
        # number of depth bins of the yz / xz planes
        Dz = W // self.stride
        # The encoder output is reshaped to (B, T, C, h, w) below, i.e. the
        # encoder works on frames flattened into the batch dimension.
        rgbs_ = rgbds[:, :, :3, ...].reshape(B * T, 3, H, W)
        depth_all = rgbds[:, :, 3, ...]
        # depth range over valid pixels (> 0.01), widened by the query depths
        d_near = self.d_near = depth_all[depth_all > 0.01].min().item()
        d_far = self.d_far = depth_all[depth_all > 0.01].max().item()
        d_near_z = queries.reshape(B, -1, 3)[:, :, 2].min().item()
        d_far_z = queries.reshape(B, -1, 3)[:, :, 2].max().item()
        d_near = min(d_near, d_near_z)
        d_far = max(d_far, d_far_z)
        # NOTE(review): min(...) caps the near plane at 0.01 (it can only
        # get smaller); confirm this is intended rather than a floor.
        d_near = min(d_near - self.depth_extend_margin, 0.01)
        d_far = d_far + self.depth_extend_margin
        depths = (depth_all - d_near) / (d_far - d_near)
        depths_dn = nn.functional.interpolate(
            depths, scale_factor=1.0 / self.stride, mode="nearest")
        # normalised depth expressed in units of the Dz depth bins
        depths_dnG = depths_dn * Dz
        # Step3: initial the regular grid
        gridx = torch.linspace(0, W // self.stride - 1, W // self.stride)
        gridy = torch.linspace(0, H // self.stride - 1, H // self.stride)
        gridx, gridy = torch.meshgrid(gridx, gridy)
        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
            2, 1, 0
        )  # 2 H W
        gridxyz = torch.cat([gridxy[None, ...].repeat(
            depths_dn.shape[0], 1, 1, 1), depths_dnG], dim=1)
        # splatting flows moving each (x, y) cell to its (y, z) / (x, z) cell
        Fxy2yz = gridxyz[:, [1, 2], ...] - gridxyz[:, :2]
        Fxy2xz = gridxyz[:, [0, 2], ...] - gridxyz[:, :2]
        if getattr(self.args, "Embed3D", None) == True:
            gridxyz_nm = gridxyz.clone()
            # min-max normalise x, y and z independently to [0, 1] ...
            for c in range(3):
                channel = gridxyz_nm[:, c, ...]
                gridxyz_nm[:, c, ...] = (
                    (channel - channel.min()) / (channel.max() - channel.min())
                )
            # ... then to [-1, 1]
            gridxyz_nm = 2 * (gridxyz_nm - 0.5)
            S, _, h4, w4 = gridxyz_nm.shape
            gridxyz_nm = gridxyz_nm.permute(0, 2, 3, 1).reshape(S * h4 * w4, 3)
            featPE = self.embed3d(gridxyz_nm).view(S, h4, w4, -1).permute(0, 3, 1, 2)
            fmaps_ = torch.cat([self.fnet(rgbs_), featPE], dim=1)
            fmaps_ = self.embedConv(fmaps_)
        else:
            fmaps_ = self.fnet(rgbs_)
        fmapXY = fmaps_[:, :self.latent_dim].reshape(
            B, T, self.latent_dim, H // self.stride, W // self.stride
        )
        # splat the xy features of batch element 0 onto the yz / xz planes
        fmapYZ = softsplat(fmapXY[0], Fxy2yz, None,
                           strMode="avg", tenoutH=Dz, tenoutW=H // self.stride)
        fmapXZ = softsplat(fmapXY[0], Fxy2xz, None,
                           strMode="avg", tenoutH=Dz, tenoutW=W // self.stride)
        fmapYZ = self.headyz(fmapYZ)[None, ...]
        fmapXZ = self.headxz(fmapXZ)[None, ...]
        # scale the coords_init; clone so the caller's queries stay untouched
        coords_init = queries[:, :1].clone()  # B 1 N 3, the first frame
        coords_init[..., :2] /= float(self.stride)
        # every query is sampled at time index 0, on the coords' device
        t = torch.zeros(B * queries.shape[2],
                        device=coords_init.device, dtype=coords_init.dtype)
        (featxy_init,
         featyz_init,
         featxz_init) = self.sample_trifeat(
            t=t, featMapxy=fmapXY,
            featMapyz=fmapYZ, featMapxz=fmapXZ,
            coords=coords_init  # B 1 N 3
        )
        return torch.stack([featxy_init, featyz_init, featxz_init], dim=-1)