# Snapshot of commit bddb8a1 ("update video writer") by Shane922.
# Copyright (2025) Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import Compose
import cv2
from tqdm import tqdm
import numpy as np
import gc
from .dinov2 import DINOv2
from .dpt_temporal import DPTHeadTemporal
from .util.transform import Resize, NormalizeImage, PrepareForNet
from utils.util import compute_scale_and_shift, get_interpolate_frames
# Inference settings, do not change: the windowing logic below relies on
# len(KEYFRAMES) == OVERLAP and on INTERP_LEN <= OVERLAP.
INFER_LEN = 32  # frames per inference window
OVERLAP = 10  # leading slots of each window filled from the previous window
INTERP_LEN = 8  # overlap slots blended between consecutive windows
# Slots of a window copied into the next window's first OVERLAP slots:
# two alignment anchors (0 and 12) followed by the window's last INTERP_LEN slots.
KEYFRAMES = [0, 12, *range(INFER_LEN - INTERP_LEN, INFER_LEN)]
class VideoDepthAnything(nn.Module):
    """Video depth estimator: a DINOv2 backbone feeding a temporal DPT head.

    ``forward`` maps a batch of clips to per-frame depth maps.
    ``infer_video_depth`` applies the model to a whole video using overlapping
    windows of ``INFER_LEN`` frames. It stitches the windows together with
    per-window scale/shift alignment.
    """

    def __init__(
        self,
        encoder='vitl',
        features=256,
        out_channels=None,
        use_bn=False,
        use_clstoken=False,
        num_frames=32,
        pe='ape'
    ):
        """Create the backbone and the temporal head.

        Args:
            encoder: DINOv2 variant name. ``forward`` looks it up in
                ``intermediate_layer_idx``, so it must be ``'vits'`` or ``'vitl'``.
            features: feature width forwarded to ``DPTHeadTemporal``.
            out_channels: per-level channel counts forwarded to
                ``DPTHeadTemporal``. ``None`` (the default) means
                ``[256, 512, 1024, 1024]``.
            use_bn: forwarded to ``DPTHeadTemporal``.
            use_clstoken: forwarded to ``DPTHeadTemporal``.
            num_frames: forwarded to ``DPTHeadTemporal``.
            pe: forwarded to ``DPTHeadTemporal``.
        """
        super().__init__()
        if out_channels is None:
            # Build the default per call. A list literal in the signature would be a
            # single object shared by every instance created without this argument.
            out_channels = [256, 512, 1024, 1024]
        # Backbone layer indices whose outputs are passed to the head, per encoder.
        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23]
        }
        self.encoder = encoder
        self.pretrained = DINOv2(model_name=encoder)
        self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)

    def forward(self, x):
        """Predict per-frame depth for a batch of clips.

        Args:
            x: tensor of shape ``[B, T, C, H, W]``. H and W are presumably
                expected to be multiples of 14, the patch size used below
                (TODO: confirm).

        Returns:
            Non-negative depth tensor of shape ``[B, T, H, W]``.
        """
        B, T, C, H, W = x.shape
        patch_h, patch_w = H // 14, W // 14
        # The backbone sees every frame as a separate image (time folded into batch);
        # the head receives T to restore the temporal dimension.
        features = self.pretrained.get_intermediate_layers(x.flatten(0, 1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
        depth = self.head(features, patch_h, patch_w, T)
        depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
        depth = F.relu(depth)
        return depth.squeeze(1).unflatten(0, (B, T))  # return shape [B, T, H, W]

    def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
        """Estimate temporally consistent depth for every frame of a video.

        The video is processed in windows of ``INFER_LEN`` frames that advance by
        ``INFER_LEN - OVERLAP``. In every window after the first, the first
        ``OVERLAP`` slots hold the previous window's ``KEYFRAMES``. Each window's
        predictions are then scale/shift-aligned to the preceding ones.

        Args:
            frames: frames indexable along the first axis (e.g. an ``N x H x W x C``
                array), each supporting ``.shape`` and ``.astype``. Pixel values are
                divided by 255, so a 0-255 range is presumably expected (confirm
                against the caller).
            target_fps: returned unchanged alongside the depths.
            input_size: nominal network input size. It is reduced for aspect ratios
                above 1.78 and rounded to a multiple of 14.
            device: torch device the model runs on.

        Returns:
            ``(depths, target_fps)``. ``depths`` stacks one depth map per input
            frame along axis 0, each at the original frame resolution.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        org_video_len = len(frames)
        if org_video_len == 0:
            raise ValueError('frames must contain at least one frame')
        frame_height, frame_width = frames[0].shape[:2]
        transform = self._build_transform(frame_height, frame_width, input_size)
        depth_list = self._infer_windows(frames, transform, frame_height, frame_width, device)
        # Release inference-time garbage before the alignment pass.
        gc.collect()
        depth_list = self._align_windows(depth_list)
        return np.stack(depth_list[:org_video_len], axis=0), target_fps

    @staticmethod
    def _build_transform(frame_height, frame_width, input_size):
        """Return the preprocessing pipeline for frames of the given size."""
        ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
        if ratio > 1.78:
            # Above ~16:9, lower the short-side target so the long side stays bounded.
            input_size = int(input_size * 1.78 / ratio)
            input_size = round(input_size / 14) * 14
            print(f'==> infer lower bound input size: {input_size}')
        return Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

    @staticmethod
    def _prepare_frame(frame, transform):
        """Scale one frame by 1/255, apply ``transform``, return it with two leading unit dims."""
        sample = transform({'image': frame.astype(np.float32) / 255.0})
        return torch.from_numpy(sample['image']).unsqueeze(0).unsqueeze(0)

    def _infer_windows(self, frames, transform, frame_height, frame_width, device):
        """Run the model on every window and return the unaligned depth maps.

        Returns a flat list with ``INFER_LEN`` depth maps per window, in window
        order, each resized to ``frame_height x frame_width``.
        """
        org_video_len = len(frames)
        last_idx = org_video_len - 1
        frame_step = INFER_LEN - OVERLAP
        depth_list = []
        pre_input = None
        for frame_id in tqdm(range(0, org_video_len, frame_step)):
            # After the first window, slots [0, OVERLAP) come from the previous
            # window's KEYFRAMES, so only the remaining slots need preprocessing.
            first_slot = 0 if pre_input is None else OVERLAP
            # Indices past the end reuse the last frame, padding the final window.
            new_frames = torch.cat(
                [self._prepare_frame(frames[min(frame_id + i, last_idx)], transform)
                 for i in range(first_slot, INFER_LEN)],
                dim=1,
            ).to(device)
            if pre_input is None:
                cur_input = new_frames
            else:
                cur_input = torch.cat([pre_input[:, KEYFRAMES, ...], new_frames], dim=1)
            with torch.no_grad():
                depth = self.forward(cur_input)  # depth shape: [1, T, H, W]
                depth = F.interpolate(depth.flatten(0, 1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
            # One device-to-host copy per window instead of one per frame.
            depth_list.extend(depth[:, 0].cpu().numpy())
            pre_input = cur_input
        return depth_list

    @staticmethod
    def _rescale(depth, scale, shift):
        """Return ``depth * scale + shift`` with negative values clamped to 0."""
        out = depth * scale + shift
        out[out < 0] = 0
        return out

    @classmethod
    def _align_windows(cls, depth_list):
        """Stitch per-window depth maps into one scale-consistent sequence.

        The first window is kept as is and serves as the reference. For every
        later window:

        * A scale/shift is fitted with ``compute_scale_and_shift`` on its first
          ``OVERLAP - INTERP_LEN`` slots against the reference depths.
        * The window is rescaled and clamped at 0.
        * Slots ``[OVERLAP - INTERP_LEN, OVERLAP)`` are merged with the tail of the
          aligned sequence via ``get_interpolate_frames``.
        * The remaining slots are appended.
        """
        align_len = OVERLAP - INTERP_LEN
        kf_align_list = KEYFRAMES[:align_len]
        depth_list_aligned = []
        ref_align = []
        for frame_id in range(0, len(depth_list), INFER_LEN):
            if not depth_list_aligned:
                # First window: taken unchanged and defines the reference scale.
                depth_list_aligned += depth_list[:INFER_LEN]
                ref_align = [depth_list[frame_id + kf_id] for kf_id in kf_align_list]
                continue
            # Slots [0, align_len) hold the previous window's first align_len
            # keyframes, so they pair one-to-one with ref_align.
            curr_align = [depth_list[frame_id + i] for i in range(len(kf_align_list))]
            scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
                                                   np.concatenate(ref_align),
                                                   np.concatenate(np.ones_like(ref_align) == 1))
            pre_depth_list = depth_list_aligned[-INTERP_LEN:]
            post_depth_list = [cls._rescale(d, scale, shift)
                               for d in depth_list[frame_id + align_len:frame_id + OVERLAP]]
            depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)
            depth_list_aligned.extend(cls._rescale(depth_list[frame_id + i], scale, shift)
                                      for i in range(OVERLAP, INFER_LEN))
            # Slot 0 always carries the same frame (KEYFRAMES[0] == 0), so keep the
            # first reference. Refresh the others with this window's aligned depths.
            ref_align = ref_align[:1] + [cls._rescale(depth_list[frame_id + kf_id], scale, shift)
                                         for kf_id in kf_align_list[1:]]
        return depth_list_aligned