# Copyright (2025) Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import Compose
import cv2
from tqdm import tqdm
import numpy as np
import gc

from .dinov2 import DINOv2
from .dpt_temporal import DPTHeadTemporal
from .util.transform import Resize, NormalizeImage, PrepareForNet

from utils.util import compute_scale_and_shift, get_interpolate_frames

# infer settings, do not change
# Number of frames fed to the model per inference window.
INFER_LEN = 32
# Leading slots of every window after the first that are filled with KEYFRAMES
# of the previous window; a window therefore advances by INFER_LEN - OVERLAP.
OVERLAP = 10
# Indices (into the previous window) copied into the first OVERLAP slots of the
# next window. Must have exactly OVERLAP entries. The first
# OVERLAP - INTERP_LEN entries anchor the scale/shift alignment; the remaining
# INTERP_LEN entries are the last INTERP_LEN frames of a window, i.e. exactly
# the tail of the already-stitched output that gets blended.
KEYFRAMES = [0, 12, 24, 25, 26, 27, 28, 29, 30, 31]
# Number of frames blended between consecutive windows by get_interpolate_frames.
INTERP_LEN = 8


def _scale_shift_clamp(depth, scale, shift):
    """Return a new array ``depth * scale + shift`` with negative values set to 0."""
    out = depth * scale + shift
    out[out < 0] = 0
    return out


def _align_window_depths(depth_list):
    """Stitch raw per-window depth maps into one sequence in a common scale.

    Args:
        depth_list: INFER_LEN depth maps per window, concatenated window after
            window (as produced by ``VideoDepthAnything._infer_windows``). Slots
            [0, OVERLAP) of every window after the first hold predictions for the
            previous window's KEYFRAMES.

    Returns:
        A list holding the first window unchanged, followed by each later
        window's INFER_LEN - OVERLAP new frames mapped into the first window's
        scale. The INTERP_LEN frames where consecutive windows overlap are
        replaced by ``get_interpolate_frames`` output.
    """
    if not depth_list:
        return []

    align_len = OVERLAP - INTERP_LEN
    kf_align_list = KEYFRAMES[:align_len]

    # The first window defines the reference scale.
    depth_list_aligned = list(depth_list[:INFER_LEN])
    ref_align = [depth_list[kf_id] for kf_id in kf_align_list]

    for frame_id in range(INFER_LEN, len(depth_list), INFER_LEN):
        # Slots [0, align_len) re-predict the previous window's anchor keyframes,
        # whose aligned depths are held in ref_align.
        curr_align = depth_list[frame_id:frame_id + align_len]
        ref = np.concatenate(ref_align)
        mask = np.ones(ref.shape, dtype=bool)
        scale, shift = compute_scale_and_shift(np.concatenate(curr_align), ref, mask)

        # Slots [align_len, OVERLAP) correspond to the last INTERP_LEN frames
        # already emitted; blend old and new predictions for those frames.
        pre_depth_list = depth_list_aligned[-INTERP_LEN:]
        post_depth_list = [
            _scale_shift_clamp(depth, scale, shift)
            for depth in depth_list[frame_id + align_len:frame_id + OVERLAP]
        ]
        depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)

        # Slots [OVERLAP, INFER_LEN) are frames not seen by any earlier window.
        depth_list_aligned.extend(
            _scale_shift_clamp(depth_list[frame_id + i], scale, shift)
            for i in range(OVERLAP, INFER_LEN)
        )

        # KEYFRAMES[0] is always the very first frame of the video (slot 0 is
        # carried over window to window), so its reference never changes; the
        # other anchors move forward to this window, expressed in the reference scale.
        ref_align = ref_align[:1] + [
            _scale_shift_clamp(depth_list[frame_id + kf_id], scale, shift)
            for kf_id in kf_align_list[1:]
        ]

    return depth_list_aligned


class VideoDepthAnything(nn.Module):
    """DINOv2 encoder plus temporal DPT head predicting per-frame depth.

    ``forward`` maps a batch of clips to depth maps. ``infer_video_depth`` runs
    it over a whole video in overlapping windows and stitches the windows into
    one consistently scaled sequence.
    """

    def __init__(
        self,
        encoder='vitl',
        features=256,
        out_channels=None,
        use_bn=False,
        use_clstoken=False,
        num_frames=32,
        pe='ape'
    ):
        """Build the encoder and the temporal head.

        Args:
            encoder: DINOv2 model name; must be a key of
                ``intermediate_layer_idx`` ('vits' or 'vitl').
            features: forwarded to DPTHeadTemporal.
            out_channels: forwarded to DPTHeadTemporal; defaults to a fresh
                ``[256, 512, 1024, 1024]`` per instance (avoids a shared
                mutable default).
            use_bn: forwarded to DPTHeadTemporal.
            use_clstoken: forwarded to DPTHeadTemporal.
            num_frames: forwarded to DPTHeadTemporal (presumably the temporal
                length it is built for — TODO confirm).
            pe: forwarded to DPTHeadTemporal.
        """
        super().__init__()

        if out_channels is None:
            out_channels = [256, 512, 1024, 1024]

        # Encoder blocks whose outputs feed the head, per backbone size.
        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23]
        }

        self.encoder = encoder
        self.pretrained = DINOv2(model_name=encoder)

        self.head = DPTHeadTemporal(
            self.pretrained.embed_dim,
            features,
            use_bn,
            out_channels=out_channels,
            use_clstoken=use_clstoken,
            num_frames=num_frames,
            pe=pe,
        )

    def forward(self, x):
        """Predict depth for a batch of clips.

        Args:
            x: tensor of shape [B, T, C, H, W]. H and W are divided by 14 to get
                the patch grid, so presumably they are multiples of 14 — TODO confirm.

        Returns:
            Depth of shape [B, T, H, W], clamped to be non-negative by ReLU.
        """
        B, T, C, H, W = x.shape
        patch_h, patch_w = H // 14, W // 14
        # Fold time into the batch axis so the encoder sees B*T images.
        features = self.pretrained.get_intermediate_layers(
            x.flatten(0, 1),
            self.intermediate_layer_idx[self.encoder],
            return_class_token=True,
        )
        depth = self.head(features, patch_h, patch_w, T)
        depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
        depth = F.relu(depth)
        return depth.squeeze(1).unflatten(0, (B, T))  # return shape [B, T, H, W]

    def _infer_windows(self, frames, transform, frame_height, frame_width, device):
        """Run the model over overlapping windows of ``frames``.

        The video is padded by repeating its last frame so that every window is
        full. Windows start every INFER_LEN - OVERLAP frames. Every window after
        the first reuses the previous window's KEYFRAMES inputs in its first
        OVERLAP slots.

        Returns:
            A list with INFER_LEN depth maps per window, window after window,
            each resized to (frame_height, frame_width). Different windows are
            not yet in a common scale.
        """
        frame_step = INFER_LEN - OVERLAP
        frame_list = list(frames)
        org_video_len = len(frame_list)
        append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
        frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len

        def prepare(frame):
            # One frame -> tensor of shape [1, 1, ...] ready for cat along dim 1.
            image = transform({'image': frame.astype(np.float32) / 255.0})['image']
            return torch.from_numpy(image).unsqueeze(0).unsqueeze(0)

        depth_list = []
        pre_input = None
        for frame_id in tqdm(range(0, org_video_len, frame_step)):
            if pre_input is None:
                cur_input = torch.cat(
                    [prepare(frame_list[frame_id + i]) for i in range(INFER_LEN)], dim=1
                ).to(device)
            else:
                # The first OVERLAP slots come from the previous window's
                # KEYFRAMES, so only the new frames need preprocessing.
                new_input = torch.cat(
                    [prepare(frame_list[frame_id + i]) for i in range(OVERLAP, INFER_LEN)], dim=1
                ).to(device)
                cur_input = torch.cat([pre_input[:, KEYFRAMES, ...], new_input], dim=1)

            with torch.no_grad():
                depth = self.forward(cur_input)  # depth shape: [1, T, H, W]

            depth = F.interpolate(
                depth.flatten(0, 1).unsqueeze(1),
                size=(frame_height, frame_width),
                mode='bilinear',
                align_corners=True,
            )
            # Single device->host copy per window; yields T maps of [H, W].
            depth_list.extend(depth[:, 0].cpu().numpy())
            pre_input = cur_input

        return depth_list

    def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
        """Estimate consistently scaled depth for every frame of a video.

        Args:
            frames: numpy array of shape [N, H, W, C] or a list of [H, W, C]
                arrays. Pixels are divided by 255 before normalization, so
                presumably 8-bit RGB — TODO confirm.
            target_fps: not used for inference; returned unchanged.
            input_size: size passed to ``Resize`` (lower-bound resize, multiple
                of 14). If the frame aspect ratio exceeds 1.78 it is reduced
                proportionally.
            device: torch device the model inputs are moved to.

        Returns:
            ``(depth, target_fps)``, where ``depth`` is ``np.stack`` of N depth
            maps, each of shape [H, W] (the original frame size).

        Raises:
            ValueError: if ``frames`` is empty.
        """
        if len(frames) == 0:
            raise ValueError('frames must contain at least one frame')

        frame_height, frame_width = frames[0].shape[:2]
        ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
        if ratio > 1.78:
            # Shrink the input size for very elongated frames, kept a multiple of 14.
            input_size = int(input_size * 1.78 / ratio)
            input_size = round(input_size / 14) * 14
            print(f'==> infer lower bound input size: {input_size}')

        transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        org_video_len = len(frames)
        depth_list = self._infer_windows(frames, transform, frame_height, frame_width, device)
        # The padded frame list is released with _infer_windows' frame; reclaim it now.
        gc.collect()

        depth_list = _align_window_depths(depth_list)
        # Drop depths predicted for the padding frames.
        return np.stack(depth_list[:org_video_len], axis=0), target_fps