Spaces:
Running
on
L40S
Running
on
L40S
add gradio demo
Browse files- .gitattributes +2 -0
- app.py +168 -0
- assets/example_videos/Tokyo-Walk_rgb.mp4 +3 -0
- assets/example_videos/davis_rollercoaster.mp4 +3 -0
- assets/teaser_video_v2.png +3 -0
- requirements.txt +19 -0
- utils/dc_utils.py +82 -0
- utils/util.py +74 -0
- video_depth_anything/dinov2.py +415 -0
- video_depth_anything/dinov2_layers/__init__.py +11 -0
- video_depth_anything/dinov2_layers/attention.py +83 -0
- video_depth_anything/dinov2_layers/block.py +252 -0
- video_depth_anything/dinov2_layers/drop_path.py +35 -0
- video_depth_anything/dinov2_layers/layer_scale.py +28 -0
- video_depth_anything/dinov2_layers/mlp.py +41 -0
- video_depth_anything/dinov2_layers/patch_embed.py +89 -0
- video_depth_anything/dinov2_layers/swiglu_ffn.py +63 -0
- video_depth_anything/dpt.py +160 -0
- video_depth_anything/dpt_temporal.py +96 -0
- video_depth_anything/motion_module/attention.py +423 -0
- video_depth_anything/motion_module/motion_module.py +288 -0
- video_depth_anything/util/blocks.py +162 -0
- video_depth_anything/util/transform.py +158 -0
- video_depth_anything/video_depth.py +149 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import gradio as gr
|
15 |
+
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import os
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from video_depth_anything.video_depth import VideoDepthAnything
|
22 |
+
from utils.dc_utils import read_video_frames, vis_sequence_depth, save_video
|
23 |
+
|
24 |
+
from huggingface_hub import hf_hub_download
|
25 |
+
|
26 |
+
# Example clips offered in the demo gallery (one input video per row).
examples = [
    ['assets/example_videos/davis_rollercoaster.mp4'],
]

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Constructor keyword arguments for each supported encoder variant.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}

# Encoder key -> suffix used in the Hugging Face model repository name.
encoder2name = {
    'vits': 'Small',
    'vitl': 'Large',
}

encoder = 'vitl'
model_name = encoder2name[encoder]

# Build the network, fetch the matching checkpoint from the Hub and load it.
video_depth_anything = VideoDepthAnything(**model_configs[encoder])
filepath = hf_hub_download(
    repo_id=f"depth-anything/Video-Depth-Anything-{model_name}",
    filename=f"video_depth_anything_{encoder}.pth",
    repo_type="model",
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded.
video_depth_anything.load_state_dict(
    torch.load(filepath, map_location='cpu', weights_only=True)
)
video_depth_anything = video_depth_anything.to(DEVICE).eval()


title = "# Video Depth Anything"
description = """Official demo for **Video Depth Anything**.
Please refer to our [project page](https://videodepthanything.github.io/), and [github](https://github.com/DepthAnything/Video-Depth-Anything) for more details."""
|
54 |
+
|
55 |
+
|
56 |
+
def infer_video_depth(
    input_video: str,
    max_len: int = -1,
    target_fps: int = -1,
    max_res: int = 1280,
    output_dir: str = './outputs',
    input_size: int = 518,
):
    """Run depth estimation on a video and write the result videos to disk.

    Args:
        input_video: Path of the video to process.
        max_len: Maximum number of frames to process; -1 means no limit.
        target_fps: Frame rate to sample the video at; -1 keeps the native rate.
        max_res: Upper bound for the longer side of the decoded frames.
        output_dir: Directory the two output videos are written to.
        input_size: Forwarded to the model as its ``input_size`` argument.

    Returns:
        ``[processed_video_path, depth_vis_path]``: the re-encoded input frames
        and the colorized depth video.

    Raises:
        gr.Error: If no input video was provided.
    """
    # The Generate button can be pressed before anything is uploaded.
    if not input_video:
        raise gr.Error("Please provide an input video first.")

    # UI sliders may deliver floats; the reader slices and strides with these.
    frames, target_fps = read_video_frames(
        input_video, int(max_len), int(target_fps), int(max_res)
    )
    depth_list, fps = video_depth_anything.infer_video_depth(
        frames, target_fps, input_size=input_size, device=DEVICE
    )
    depth_list = np.stack(depth_list, axis=0)
    vis = vis_sequence_depth(depth_list)

    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs(output_dir, exist_ok=True)

    stem = os.path.splitext(os.path.basename(input_video))[0]
    processed_video_path = os.path.join(output_dir, stem + '_src.mp4')
    depth_vis_path = os.path.join(output_dir, stem + '_vis.mp4')
    save_video(frames, processed_video_path, fps=fps)
    save_video(vis, depth_vis_path, fps=fps)

    return [processed_video_path, depth_vis_path]
|
78 |
+
|
79 |
+
|
80 |
+
def construct_demo():
    """Assemble the Gradio Blocks interface for the demo and return it.

    The returned ``gr.Blocks`` is not launched; the caller queues and launches it.
    """
    # NOTE(review): widget nesting was reconstructed from a flattened diff view;
    # confirm the layout against the repository copy of this file.

    # Options shared by the two read-only output players.
    output_player_opts = dict(
        interactive=False,
        autoplay=True,
        loop=True,
        show_share_button=True,
        scale=5,
    )

    with gr.Blocks(analytics_enabled=False) as demo:
        for heading in (title, description, "### Video Depth Prediction demo"):
            gr.Markdown(heading)

        # Top row: uploader on the left, the two result players on the right.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")

            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    processed_video = gr.Video(label="Preprocessed video", **output_player_opts)
                    depth_vis_video = gr.Video(label="Generated Depth Video", **output_player_opts)

        # Second row: advanced settings and the trigger button.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    with gr.Accordion("Advanced Settings", open=False):
                        max_len = gr.Slider(
                            label="max process length", minimum=-1, maximum=1000, value=-1, step=1
                        )
                        target_fps = gr.Slider(
                            label="target FPS", minimum=-1, maximum=30, value=15, step=1
                        )
                        max_res = gr.Slider(
                            label="max side resolution", minimum=480, maximum=1920, value=1280, step=1
                        )
                    generate_btn = gr.Button("Generate")
            # Empty column, present only to occupy the remaining width.
            with gr.Column(scale=2):
                pass

        controls = [input_video, max_len, target_fps, max_res]
        results = [processed_video, depth_vis_video]

        gr.Examples(
            examples=examples,
            inputs=list(controls),
            outputs=list(results),
            fn=infer_video_depth,
            cache_examples="lazy",
        )

        generate_btn.click(
            fn=infer_video_depth,
            inputs=list(controls),
            outputs=list(results),
        )

    return demo
|
164 |
+
|
165 |
+
# Script entry point: build the UI, enable request queueing, then serve it.
if __name__ == "__main__":
    demo = construct_demo()
    demo.queue()
    demo.launch(share=True)
|
assets/example_videos/Tokyo-Walk_rgb.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:097f16c33dd8c8d1d2a24d9ea31a90b76bd0ee324b958a47385183e3547a63a8
|
3 |
+
size 2251450
|
assets/example_videos/davis_rollercoaster.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7268cbecd9806a1e90a416de50dc02e50b4ae01428d5971837cf679dd0c87cb8
|
3 |
+
size 1809560
|
assets/teaser_video_v2.png
ADDED
Git LFS Details
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio_imageslider
|
2 |
+
gradio==4.36.0
|
3 |
+
torch
|
4 |
+
torchvision
|
5 |
+
opencv-python
|
6 |
+
matplotlib
|
7 |
+
huggingface_hub
|
8 |
+
typing
|
9 |
+
tempfile
|
10 |
+
pillow
|
11 |
+
mediapy
|
12 |
+
decord
|
13 |
+
xformers
|
14 |
+
einops
|
15 |
+
math
|
16 |
+
functools
|
17 |
+
logging
|
18 |
+
easydict
|
19 |
+
tqdm
|
utils/dc_utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter
|
2 |
+
# SPDX-License-Identifier: MIT License license
|
3 |
+
#
|
4 |
+
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
|
5 |
+
# Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file].
|
6 |
+
from typing import Union, List
|
7 |
+
import tempfile
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
import matplotlib.cm as cm
|
11 |
+
import mediapy
|
12 |
+
import torch
|
13 |
+
from decord import VideoReader, cpu
|
14 |
+
|
15 |
+
|
16 |
+
def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1, dataset="open"):
    """Decode a video into a frame array, optionally downscaled and subsampled.

    Args:
        video_path: Path of the video file to read.
        process_length: Maximum number of frames to return. Values <= 0
            (including the conventional -1) mean "no limit".
        target_fps: Desired sampling rate. Values <= 0 (including the
            conventional -1) keep the video's average frame rate.
        max_res: If > 0, frames are scaled so the longer side does not exceed
            this value; aspect ratio is kept up to rounding.
        dataset: Unused; kept for interface compatibility.

    Returns:
        ``(frames, fps)`` where ``frames`` is the decoded batch converted with
        ``asnumpy()`` and ``fps`` is the frame rate the caller should use for
        the sampled frames.
    """
    # First pass only probes the native frame size.
    probe = VideoReader(video_path, ctx=cpu(0))
    original_height, original_width = probe.get_batch([0]).shape[1:3]
    height, width = original_height, original_width
    if max_res > 0 and max(height, width) > max_res:
        scale = max_res / max(original_height, original_width)
        height = round(original_height * scale)
        width = round(original_width * scale)

    # Second pass decodes at the (possibly reduced) resolution.
    vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)

    native_fps = vid.get_avg_fps()
    # A target of 0 used to divide by zero below; treat any non-positive
    # target as "keep the native rate".
    fps = native_fps if target_fps <= 0 else target_fps
    stride = max(round(native_fps / fps), 1) if fps > 0 else 1

    frames_idx = list(range(0, len(vid), stride))
    # Only a positive limit truncates: 0 would select no frames at all and a
    # negative value other than -1 would silently drop frames from the end.
    if process_length > 0 and process_length < len(frames_idx):
        frames_idx = frames_idx[:process_length]
    frames = vid.get_batch(frames_idx).asnumpy()

    return frames, fps
|
38 |
+
|
39 |
+
|
40 |
+
def save_video(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
    output_video_path: str = None,
    fps: int = 10,
    crf: int = 18,
) -> str:
    """Encode a sequence of frames to a video file with mediapy.

    Args:
        video_frames: Frames as NumPy arrays (cast to uint8) or PIL images
            (converted to arrays).
        output_video_path: Destination path. When ``None`` a fresh ``.mp4``
            temporary file is created and its path is used.
        fps: Frame rate passed to the encoder.
        crf: ``crf`` value passed to the encoder.

    Returns:
        The path the video was written to.
    """
    if output_video_path is None:
        # delete=False keeps the file after the handle is closed. Without it
        # the file is removed as soon as the temporary object is collected,
        # leaving only a name that another process could reuse.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_video_path = tmp.name

    first = video_frames[0]
    if isinstance(first, np.ndarray):
        video_frames = [frame.astype(np.uint8) for frame in video_frames]
    elif isinstance(first, PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
    return output_video_path
|
56 |
+
|
57 |
+
|
58 |
+
class ColorMapper:
    """Map scalar (depth) values to RGB colors through a matplotlib colormap."""

    def __init__(self, colormap: str = "inferno"):
        """Load the color table of the named colormap as a tensor.

        Args:
            colormap: Name of a matplotlib colormap that exposes a ``colors``
                table.
        """
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; prefer the
        # colormap registry and fall back to the old helper where it is missing.
        import matplotlib

        registry = getattr(matplotlib, "colormaps", None)
        cmap = registry[colormap] if registry is not None else cm.get_cmap(colormap)
        self.colormap = torch.tensor(cmap.colors)

    def apply(self, image: torch.Tensor, v_min=None, v_max=None):
        """Colorize ``image`` by normalizing it to the color-table index range.

        Args:
            image: Tensor of scalar values.
            v_min: Value mapped to the first color; defaults to ``image.min()``.
            v_max: Value mapped to the last color; defaults to ``image.max()``.

        Returns:
            A tensor with one extra trailing color axis, scaled by 255.
            Values outside ``[v_min, v_max]`` are clamped to the end colors.
        """
        if v_min is None:
            v_min = image.min()
        if v_max is None:
            v_max = image.max()

        span = v_max - v_min
        if float(span) == 0.0:
            # Constant input: dividing would give NaN indices, so map
            # everything to the first color instead.
            normalized = torch.zeros_like(image, dtype=torch.float32)
        else:
            # Clamp so caller-supplied bounds can never index outside the table.
            normalized = ((image - v_min) / span).clamp(0.0, 1.0)

        index = (normalized * 255).long()
        return self.colormap[index] * 255
|
73 |
+
|
74 |
+
|
75 |
+
def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
    """Colorize a depth array with the default ``ColorMapper``.

    Args:
        depths: Depth values as a NumPy array.
        v_min: Lower normalization bound; defaults to ``depths.min()``.
        v_max: Upper normalization bound; defaults to ``depths.max()``.

    Returns:
        The colorized result converted back to a NumPy array.
    """
    lower = depths.min() if v_min is None else v_min
    upper = depths.max() if v_max is None else v_max
    colored = ColorMapper().apply(torch.tensor(depths), v_min=lower, v_max=upper)
    return colored.numpy()
|
utils/util.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
def compute_scale_and_shift(prediction, target, mask, scale_only=False):
    """Fit ``prediction`` to ``target`` over ``mask``.

    Returns ``(scale, shift)``. With ``scale_only`` the shift is fixed to 0
    and only the scale is fitted.
    """
    if not scale_only:
        return compute_scale_and_shift_full(prediction, target, mask)
    return compute_scale(prediction, target, mask), 0
|
21 |
+
|
22 |
+
|
23 |
+
def compute_scale(prediction, target, mask):
    """Return the least-squares scale ``s`` minimizing ``mask * (s * prediction - target)**2``.

    Args:
        prediction: Array of predicted values.
        target: Array of reference values.
        mask: Array of per-element weights (typically 0/1).

    Returns:
        The fitted scale. A small epsilon in the denominator keeps the result
        finite (0) when the masked prediction energy is zero.
    """
    prediction = prediction.astype(np.float32)
    target = target.astype(np.float32)
    mask = mask.astype(np.float32)

    # Normal equation of the one-parameter fit: s = sum(m*p*t) / sum(m*p*p).
    pred_energy = np.sum(mask * prediction * prediction)
    pred_target = np.sum(mask * prediction * target)

    return pred_target / (pred_energy + 1e-6)
|
39 |
+
|
40 |
+
def compute_scale_and_shift_full(prediction, target, mask):
    """Solve the masked least-squares fit ``scale * prediction + shift ~= target``.

    Returns ``(scale, shift)``. When the 2x2 normal-equation matrix is
    singular (determinant exactly 0) the identity fit ``(1, 0)`` is returned.
    """
    pred = prediction.astype(np.float32)
    ref = target.astype(np.float32)
    weight = mask.astype(np.float32)

    # Symmetric system matrix A = [[a_00, a_01], [a_01, a_11]].
    a_00 = np.sum(weight * pred * pred)
    a_01 = np.sum(weight * pred)
    a_11 = np.sum(weight)

    # Right-hand side b = [b_0, b_1].
    b_0 = np.sum(weight * pred * ref)
    b_1 = np.sum(weight * ref)

    det = a_00 * a_11 - a_01 * a_01
    if det == 0:
        return 1, 0

    # Closed-form solution of the 2x2 system.
    scale = (a_11 * b_0 - a_01 * b_1) / det
    shift = (-a_01 * b_0 + a_00 * b_1) / det
    return scale, shift
|
63 |
+
|
64 |
+
|
65 |
+
def get_interpolate_frames(frame_list_pre, frame_list_post):
    """Cross-fade two equally long frame sequences.

    The weight of ``frame_list_post`` ramps linearly from 0 on the first
    frame to 1 on the last; ``frame_list_pre`` gets the complementary weight.

    Args:
        frame_list_pre: Sequence of frames to fade out.
        frame_list_post: Sequence of frames to fade in (same length).

    Returns:
        A list of blended frames, one per input position. Empty inputs give
        an empty list. A single-frame input has no ramp to build, so the two
        frames are averaged.
    """
    assert len(frame_list_pre) == len(frame_list_post)
    count = len(frame_list_pre)

    if count == 1:
        # (max_w - min_w) / (count - 1) would divide by zero here.
        post_w_list = [0.5]
    else:
        min_w = 0.0
        max_w = 1.0
        step = (max_w - min_w) / (count - 1)
        post_w_list = [min_w] + [i * step for i in range(1, count - 1)] + [max_w]

    return [
        pre * (1 - w) + post * w
        for pre, post, w in zip(frame_list_pre, frame_list_post, post_w_list)
    ]
|
video_depth_anything/dinov2.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
|
20 |
+
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on every descendant of ``module``.

    Names are dotted paths built from ``named_children``. ``module`` itself is
    visited only when ``include_root`` is true; ``depth_first`` selects whether
    a module is visited after (True) or before (False) its children.
    Returns ``module``.
    """
    visit_self = include_root
    if visit_self and not depth_first:
        fn(module=module, name=name)

    for child_name, child_module in module.named_children():
        qualified = ".".join((name, child_name)) if name else child_name
        # Every descendant is a "root" of its own subtree and is always visited.
        named_apply(fn=fn, module=child_module, name=qualified, depth_first=depth_first, include_root=True)

    if visit_self and depth_first:
        fn(module=module, name=name)
    return module
|
35 |
+
|
36 |
+
|
37 |
+
class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` that can be called: its members are applied in order."""

    def forward(self, x):
        out = x
        for layer in self:
            out = layer(out)
        return out
|
42 |
+
|
43 |
+
|
44 |
+
class DinoVisionTransformer(nn.Module):
    """Vision Transformer backbone (DINOv2 variant).

    Embeds an image into patch tokens, prepends a class token (and optional
    register tokens), runs the transformer blocks and exposes final or
    intermediate token features.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings

        Raises:
            NotImplementedError: if ``ffn_layer`` is not one of the supported names.
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable class token and one positional embedding per (class + patch) token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # Resolve the FFN name to the class/factory handed to every block.
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            # Factory that ignores the block's FFN arguments and returns a no-op.
            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Embedding substituted for patch tokens selected by `masks`.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize positional/class/register tokens and all Linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Return positional embeddings resized to match the tokens in ``x``.

        ``x`` is the token sequence including the class token; ``w`` and ``h``
        are the two spatial sizes of the input image as unpacked by the caller.
        The stored embedding is returned unchanged when the patch count matches
        and ``w == h``; otherwise the patch grid is resampled bicubically.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        # w0, h0 = w0 + 0.1, h0 + 0.1

        # The stored patch embeddings are treated as a square sqrt(N) x sqrt(N) grid.
        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            # (int(w0), int(h0)), # to solve the upsampling shape issue
            mode="bicubic",
            antialias=self.interpolate_antialias
        )

        # Sanity check: the resampled grid must match the expected patch grid.
        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Turn an image batch into the token sequence fed to the blocks.

        Patch-embeds ``x``, replaces tokens selected by ``masks`` with the mask
        token, prepends the class token, adds positional embeddings and, when
        configured, inserts the register tokens right after the class token.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """List variant of ``forward_features``: one output dict per input.

        NOTE(review): each block is called with the whole list of token
        tensors, so this presumably relies on the block class accepting
        lists -- confirm against the block implementation.
        """
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the full backbone and return a dict of token features.

        The dict holds the normalized class, register and patch tokens, the
        pre-normalization tokens (``x_prenorm``) and the ``masks`` passed in.
        A list input is dispatched to ``forward_features_list``.
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect block outputs when ``self.blocks`` is a flat list of blocks."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect block outputs when ``self.blocks`` holds ``BlockChunk`` groups."""
        x = self.prepare_tokens_with_masks(x)
        # The last chunk is padded with placeholders up to its first real block,
        # so its length is used as the total block count.
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch-token features from selected transformer blocks.

        Args:
            x: Input image batch.
            n: Number of last blocks to take, or an explicit collection of
                block indices.
            reshape: If True, reshape each output to a
                (batch, channels, w // patch_size, h // patch_size) map.
            return_class_token: If True, pair each output with its class token.
            norm: If True, apply the final norm layer to every selected output.

        Returns:
            A tuple with one entry per selected block: the patch tokens, or a
            ``(patch_tokens, class_token)`` pair when ``return_class_token``.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop the class token and any register tokens, keeping patch tokens only.
        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the full feature dict when ``is_training``, else the head applied to the class token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
|
329 |
+
|
330 |
+
|
331 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Only ``nn.Linear`` modules are touched: weights get a truncated normal
    with std 0.02 and an existing bias is zeroed. ``name`` is unused.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
|
337 |
+
|
338 |
+
|
339 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build the small variant: embed dim 384, 12 blocks, 6 heads."""
    mem_eff_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=mem_eff_block,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
351 |
+
|
352 |
+
|
353 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ``DinoVisionTransformer`` with embed_dim=768, depth=12, 12 heads.

    Blocks are created through ``Block`` with ``MemEffAttention`` as the
    attention class. Any extra keyword arguments are forwarded to the
    ``DinoVisionTransformer`` constructor unchanged.
    """
    block_factory = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=block_factory,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
365 |
+
|
366 |
+
|
367 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ``DinoVisionTransformer`` with embed_dim=1024, depth=24, 16 heads.

    Blocks are created through ``Block`` with ``MemEffAttention`` as the
    attention class. Any extra keyword arguments are forwarded to the
    ``DinoVisionTransformer`` constructor unchanged.
    """
    block_factory = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=block_factory,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
379 |
+
|
380 |
+
|
381 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64

    Depth is 40. Blocks are created through ``Block`` with ``MemEffAttention``
    as the attention class; extra keyword arguments are forwarded to the
    ``DinoVisionTransformer`` constructor unchanged.
    """
    block_factory = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=block_factory,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
396 |
+
|
397 |
+
|
398 |
+
def DINOv2(model_name):
    """Instantiate the backbone registered under ``model_name``.

    Recognised names are ``"vits"``, ``"vitb"``, ``"vitl"`` and ``"vitg"``;
    any other value raises ``KeyError`` from the lookup. Every variant is
    built with the same fixed configuration (image size 518, patch size 14,
    no register tokens); only ``"vitg"`` is given ``ffn_layer="swiglufused"``,
    all others use ``"mlp"``.
    """
    builders = {
        "vits": vit_small,
        "vitb": vit_base,
        "vitl": vit_large,
        "vitg": vit_giant2,
    }
    build = builders[model_name]
    ffn_layer = "swiglufused" if model_name == "vitg" else "mlp"

    return build(
        img_size=518,
        patch_size=14,
        init_values=1.0,
        ffn_layer=ffn_layer,
        block_chunks=0,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    )
|
video_depth_anything/dinov2_layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .mlp import Mlp
|
8 |
+
from .patch_embed import PatchEmbed
|
9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
10 |
+
from .block import NestedTensorBlock
|
11 |
+
from .attention import MemEffAttention
|
video_depth_anything/dinov2_layers/attention.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger("dinov2")
|
18 |
+
|
19 |
+
|
20 |
+
try:
|
21 |
+
from xformers.ops import memory_efficient_attention, unbind, fmha
|
22 |
+
|
23 |
+
XFORMERS_AVAILABLE = True
|
24 |
+
except ImportError:
|
25 |
+
logger.warning("xFormers not available")
|
26 |
+
XFORMERS_AVAILABLE = False
|
27 |
+
|
28 |
+
|
29 |
+
class Attention(nn.Module):
    """Multi-head self-attention with one fused query/key/value projection.

    ``forward`` takes a 3-D tensor unpacked as ``(B, N, C)`` and returns a
    tensor of the same shape.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Queries are multiplied by head_dim ** -0.5 in forward().
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, tokens, channels = x.shape
        heads = self.num_heads

        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        # Reorder to (3, batch, heads, tokens, head_dim) and split q / k / v.
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
|
63 |
+
|
64 |
+
|
65 |
+
class MemEffAttention(Attention):
    """Attention that uses xFormers' ``memory_efficient_attention`` when available.

    When xFormers could not be imported the call falls back to
    ``Attention.forward``; that path does not support ``attn_bias``.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Apply self-attention to ``x``, optionally with an xFormers ``attn_bias``.

        Raises:
            AssertionError: if ``attn_bias`` is given but xFormers is unavailable.
        """
        if not XFORMERS_AVAILABLE:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            if attn_bias is not None:
                raise AssertionError("xFormers is required for nested tensors usage")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        # Split the fused projection along the q/k/v axis.
        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
82 |
+
|
83 |
+
|
video_depth_anything/dinov2_layers/block.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn, Tensor
|
16 |
+
|
17 |
+
from .attention import Attention, MemEffAttention
|
18 |
+
from .drop_path import DropPath
|
19 |
+
from .layer_scale import LayerScale
|
20 |
+
from .mlp import Mlp
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
try:
|
27 |
+
from xformers.ops import fmha
|
28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
29 |
+
|
30 |
+
XFORMERS_AVAILABLE = True
|
31 |
+
except ImportError:
|
32 |
+
logger.warning("xFormers not available")
|
33 |
+
XFORMERS_AVAILABLE = False
|
34 |
+
|
35 |
+
|
36 |
+
class Block(nn.Module):
    """Transformer block with two residual branches: attention and feed-forward.

    Each branch is ``norm -> (attn | mlp) -> LayerScale``. ``LayerScale`` is
    only instantiated when ``init_values`` is truthy, and ``DropPath`` only
    when ``drop_path > 0``; otherwise ``nn.Identity`` takes their place.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Also used by forward() to choose the stochastic-depth strategy.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply both residual branches to ``x`` and return the result.

        Three regimes, chosen from ``self.training`` and ``sample_drop_ratio``:
        training with a ratio above 0.1 runs each branch on a random batch
        subset; training with a ratio in (0, 0.1] uses the ``DropPath``
        modules; everything else is a plain residual add.
        """

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own DropPath module; both are built with
            # the same drop probability in __init__.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
108 |
+
|
109 |
+
|
110 |
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch.

    A random selection of ``max(int(b * (1 - sample_drop_ratio)), 1)`` batch
    rows is passed through ``residual_func``; the result is scaled by
    ``b / subset_size`` and added back onto those rows only. Rows that were
    not selected are returned unchanged. ``x`` must be 3-D.
    """
    batch_size, _, _ = x.shape
    kept = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    chosen = torch.randperm(batch_size, device=x.device)[:kept]

    # Evaluate the branch on the selected rows only.
    branch_out = residual_func(x[chosen]).flatten(1)

    # Scale by b / kept to compensate for the rows that were skipped.
    scale = batch_size / kept

    summed = torch.index_add(
        x.flatten(1), 0, chosen, branch_out.to(dtype=x.dtype), alpha=scale
    )
    return summed.view_as(x)
|
132 |
+
|
133 |
+
|
134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick a random batch subset and the matching residual scale factor.

    Returns ``(brange, scale)``: ``brange`` holds
    ``max(int(b * (1 - sample_drop_ratio)), 1)`` indices taken from a random
    permutation of the batch, and ``scale`` is ``b / len(brange)``.
    ``x`` must be 3-D.
    """
    batch_size, _, _ = x.shape
    kept = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    chosen = torch.randperm(batch_size, device=x.device)[:kept]
    return chosen, batch_size / kept
|
140 |
+
|
141 |
+
|
142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` onto the rows of ``x`` selected by ``brange``.

    Without ``scaling_vector`` both tensors are flattened from dim 1 and
    combined with ``torch.index_add`` (``alpha=residual_scale_factor``), so
    the result is the flattened 2-D tensor. With ``scaling_vector`` the work
    is delegated to xFormers' ``scaled_index_add`` on the unflattened inputs.
    """
    if scaling_vector is not None:
        return scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    flat_residual = residual.flatten(1).to(dtype=x.dtype)
    return torch.index_add(x.flatten(1), 0, brange, flat_residual, alpha=residual_scale_factor)
|
152 |
+
|
153 |
+
|
154 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
155 |
+
|
156 |
+
|
157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache

    The block-diagonal mask is built once per distinct tuple of
    ``(batch_size, sequence_length)`` pairs and stored in the module-level
    ``attn_bias_cache``. Returns ``(attn_bias, cat_tensors)``.
    """
    if branges is not None:
        batch_sizes = [b.shape[0] for b in branges]
    else:
        batch_sizes = [x.shape[0] for x in x_list]

    cache_key = tuple((bs, x.shape[1]) for bs, x in zip(batch_sizes, x_list))
    if cache_key not in attn_bias_cache:
        # One sequence-length entry per sample in every tensor.
        seqlens = [x.shape[1] for bs, x in zip(batch_sizes, x_list) for _ in range(bs)]
        mask = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        mask._batch_sizes = batch_sizes
        attn_bias_cache[cache_key] = mask

    if branges is None:
        # No subsetting: fold each batch into the token axis and concatenate.
        folded = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(folded, dim=1)
    else:
        flattened = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flattened, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[cache_key], cat_tensors
|
179 |
+
|
180 |
+
|
181 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Batch-subset stochastic depth applied to a list of tensors in one pass.

    For each tensor a random batch subset is drawn, the subsets are
    concatenated and sent through ``residual_func`` once (called with an
    ``attn_bias`` keyword), and the per-tensor residuals are added back onto
    the selected rows.

    Returns:
        A list with one output per input tensor, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) add each residual back onto its own tensor
    return [
        add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)
        for x, brange, residual, residual_scale_factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
|
202 |
+
|
203 |
+
|
204 |
+
class NestedTensorBlock(Block):
    """``Block`` that also accepts a list of tensors processed as one sequence.

    A single ``Tensor`` goes through ``Block.forward``; a ``list`` goes
    through ``forward_nested``, which requires xFormers.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            # LayerScale is not applied inside these functions: its gamma is
            # handed to the helper as ``scaling_vector`` instead.
            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check ls2 itself before reading its gamma.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Concatenate all inputs, run both branches once, then split the
            # result back with the mask returned alongside the concatenation.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type: ``Tensor`` -> ``Block.forward``, ``list`` -> nested path.

        Raises:
            AssertionError: for a list input without xFormers, or for any
                other input type.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
|
video_depth_anything/dinov2_layers/drop_path.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
10 |
+
|
11 |
+
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    Returns ``x`` itself when ``training`` is false or ``drop_prob`` equals
    0. Otherwise one Bernoulli(``1 - drop_prob``) value is drawn per entry
    along dim 0, divided by the keep probability when that is positive, and
    multiplied into ``x`` with broadcasting over the remaining dims.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One mask value per sample; trailing dims are 1 so it broadcasts.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
25 |
+
|
26 |
+
|
27 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper around ``drop_path`` that supplies the stored
    ``drop_prob`` and the module's ``training`` flag.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
|
video_depth_anything/dinov2_layers/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
8 |
+
|
9 |
+
from typing import Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import Tensor
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
|
16 |
+
class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``.

    ``gamma`` starts as ``init_values * torch.ones(dim)``. With
    ``inplace=True`` the input tensor is modified and returned; otherwise a
    new tensor is produced.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
video_depth_anything/dinov2_layers/mlp.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
10 |
+
|
11 |
+
|
12 |
+
from typing import Callable, Optional
|
13 |
+
|
14 |
+
from torch import Tensor, nn
|
15 |
+
|
16 |
+
|
17 |
+
class Mlp(nn.Module):
    """Two-layer feed-forward network: ``fc1 -> act -> drop -> fc2 -> drop``.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are falsy. The same ``nn.Dropout`` module is applied after
    both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_width = hidden_features or in_features
        out_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
video_depth_anything/dinov2_layers/patch_embed.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
from typing import Callable, Optional, Tuple, Union
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
import torch.nn as nn
|
15 |
+
|
16 |
+
|
17 |
+
def make_2tuple(x):
    """Return ``x`` as a 2-tuple.

    A tuple is returned unchanged after asserting it has length 2; an int
    ``x`` becomes ``(x, x)``. Anything else fails the assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
|
24 |
+
|
25 |
+
|
26 |
+
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: When false, ``forward`` reshapes its output to
            (B, H', W', D) instead of returning (B, N, D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` of shape (B, C, H, W) to patch tokens.

        H and W must be multiples of the patch size; they are taken from the
        input itself, not from ``img_size``.
        """
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Estimate FLOPs for one image at the configured ``img_size``.

        Counts the patch projection and, only when a real normalisation layer
        is configured, one operation per output element for the norm.
        """
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # ``self.norm`` is never None (it is nn.Identity when no norm_layer
        # was given), so test for the identity placeholder instead.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
video_depth_anything/dinov2_layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import Callable, Optional
|
8 |
+
|
9 |
+
from torch import Tensor, nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer: ``w3(silu(x1) * x2)`` with ``[x1, x2] = w12(x)``.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when falsy. ``act_layer`` and ``drop`` are accepted but not used by this
    class.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_width = hidden_features or in_features
        out_width = out_features or in_features
        # A single linear layer produces both halves of the gate.
        self.w12 = nn.Linear(in_features, 2 * hidden_width, bias=bias)
        self.w3 = nn.Linear(hidden_width, out_width, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
34 |
+
|
35 |
+
|
36 |
+
try:
|
37 |
+
from xformers.ops import SwiGLU
|
38 |
+
|
39 |
+
XFORMERS_AVAILABLE = True
|
40 |
+
except ImportError:
|
41 |
+
SwiGLU = SwiGLUFFN
|
42 |
+
XFORMERS_AVAILABLE = False
|
43 |
+
|
44 |
+
|
45 |
+
class SwiGLUFFNFused(SwiGLU):
    """``SwiGLU`` with the hidden width reduced to two thirds, rounded up to a multiple of 8.

    ``act_layer`` and ``drop`` are accepted but not passed to the base class.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        resolved_out = out_features or in_features
        resolved_hidden = hidden_features or in_features
        # Take 2/3 of the requested width, then round up to the next multiple of 8.
        resolved_hidden = 8 * ((int(resolved_hidden * 2 / 3) + 7) // 8)
        super().__init__(
            in_features=in_features,
            hidden_features=resolved_hidden,
            out_features=resolved_out,
            bias=bias,
        )
|
video_depth_anything/dpt.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
from .util.blocks import FeatureFusionBlock, _make_scratch
|
19 |
+
|
20 |
+
|
21 |
+
def _make_fusion_block(features, use_bn, size=None):
    """Create a ``FeatureFusionBlock`` with this file's fixed settings.

    The block uses a non-inplace ReLU, ``deconv=False``, ``expand=False`` and
    ``align_corners=True``; ``use_bn`` is passed as ``bn`` and ``size`` is
    forwarded unchanged.
    """
    activation = nn.ReLU(False)
    return FeatureFusionBlock(
        features,
        activation,
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )
|
31 |
+
|
32 |
+
|
33 |
+
class ConvBlock(nn.Module):
    """3x3 convolution (stride 1, padding 1) followed by BatchNorm2d and an in-place ReLU."""

    def __init__(self, in_feature, out_feature):
        super().__init__()

        layers = [
            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_feature),
            nn.ReLU(True),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_block(x)
|
45 |
+
|
46 |
+
|
47 |
+
class DPTHead(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
in_channels,
|
51 |
+
features=256,
|
52 |
+
use_bn=False,
|
53 |
+
out_channels=[256, 512, 1024, 1024],
|
54 |
+
use_clstoken=False
|
55 |
+
):
|
56 |
+
super(DPTHead, self).__init__()
|
57 |
+
|
58 |
+
self.use_clstoken = use_clstoken
|
59 |
+
|
60 |
+
self.projects = nn.ModuleList([
|
61 |
+
nn.Conv2d(
|
62 |
+
in_channels=in_channels,
|
63 |
+
out_channels=out_channel,
|
64 |
+
kernel_size=1,
|
65 |
+
stride=1,
|
66 |
+
padding=0,
|
67 |
+
) for out_channel in out_channels
|
68 |
+
])
|
69 |
+
|
70 |
+
self.resize_layers = nn.ModuleList([
|
71 |
+
nn.ConvTranspose2d(
|
72 |
+
in_channels=out_channels[0],
|
73 |
+
out_channels=out_channels[0],
|
74 |
+
kernel_size=4,
|
75 |
+
stride=4,
|
76 |
+
padding=0),
|
77 |
+
nn.ConvTranspose2d(
|
78 |
+
in_channels=out_channels[1],
|
79 |
+
out_channels=out_channels[1],
|
80 |
+
kernel_size=2,
|
81 |
+
stride=2,
|
82 |
+
padding=0),
|
83 |
+
nn.Identity(),
|
84 |
+
nn.Conv2d(
|
85 |
+
in_channels=out_channels[3],
|
86 |
+
out_channels=out_channels[3],
|
87 |
+
kernel_size=3,
|
88 |
+
stride=2,
|
89 |
+
padding=1)
|
90 |
+
])
|
91 |
+
|
92 |
+
if use_clstoken:
|
93 |
+
self.readout_projects = nn.ModuleList()
|
94 |
+
for _ in range(len(self.projects)):
|
95 |
+
self.readout_projects.append(
|
96 |
+
nn.Sequential(
|
97 |
+
nn.Linear(2 * in_channels, in_channels),
|
98 |
+
nn.GELU()))
|
99 |
+
|
100 |
+
self.scratch = _make_scratch(
|
101 |
+
out_channels,
|
102 |
+
features,
|
103 |
+
groups=1,
|
104 |
+
expand=False,
|
105 |
+
)
|
106 |
+
|
107 |
+
self.scratch.stem_transpose = None
|
108 |
+
|
109 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
110 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
111 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
112 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
113 |
+
|
114 |
+
head_features_1 = features
|
115 |
+
head_features_2 = 32
|
116 |
+
|
117 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
118 |
+
self.scratch.output_conv2 = nn.Sequential(
|
119 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
120 |
+
nn.ReLU(True),
|
121 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
122 |
+
nn.ReLU(True),
|
123 |
+
nn.Identity(),
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, out_features, patch_h, patch_w):
    """Decode four levels of backbone tokens into a dense prediction map.

    Args:
        out_features: Iterable of four per-level entries. ``entry[0]`` holds
            the patch tokens, a 3-D tensor whose middle axis has
            ``patch_h * patch_w`` elements; ``entry[1]`` holds the class
            token and is only read when ``self.use_clstoken`` is set.
        patch_h: Number of patches along the height.
        patch_w: Number of patches along the width.

    Returns:
        The result of ``self.scratch.output_conv2`` applied to features that
        were bilinearly resized to ``(patch_h * 14, patch_w * 14)``.
    """
    pyramid = []
    for level, feature in enumerate(out_features):
        tokens = feature[0]
        if self.use_clstoken:
            # Broadcast the class token over every patch token and fuse both.
            cls_token = feature[1]
            readout = cls_token.unsqueeze(1).expand_as(tokens)
            tokens = self.readout_projects[level](torch.cat((tokens, readout), -1))

        # (batch, tokens, channels) -> (batch, channels, patch_h, patch_w)
        batch, channels = tokens.shape[0], tokens.shape[-1]
        fmap = tokens.permute(0, 2, 1).reshape((batch, channels, patch_h, patch_w))

        fmap = self.projects[level](fmap)
        fmap = self.resize_layers[level](fmap)
        pyramid.append(fmap)

    layer_1, layer_2, layer_3, layer_4 = pyramid

    scratch = self.scratch
    layer_1_rn = scratch.layer1_rn(layer_1)
    layer_2_rn = scratch.layer2_rn(layer_2)
    layer_3_rn = scratch.layer3_rn(layer_3)
    layer_4_rn = scratch.layer4_rn(layer_4)

    # Coarse-to-fine fusion: each stage is resized to the next finer level.
    path_4 = scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
    path_3 = scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
    path_2 = scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
    path_1 = scratch.refinenet1(path_2, layer_1_rn)

    # NOTE(review): 14 is presumably the backbone patch size -- confirm.
    target_size = (int(patch_h * 14), int(patch_w * 14))
    head = scratch.output_conv1(path_1)
    head = F.interpolate(head, target_size, mode="bilinear", align_corners=True)
    return scratch.output_conv2(head)
|
160 |
+
|
video_depth_anything/dpt_temporal.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.nn as nn
|
17 |
+
from .dpt import DPTHead
|
18 |
+
from .motion_module.motion_module import TemporalModule
|
19 |
+
from easydict import EasyDict
|
20 |
+
|
21 |
+
|
22 |
+
class DPTHeadTemporal(DPTHead):
    """DPT decoder head with temporal modules mixed in across video frames.

    Four ``TemporalModule`` instances are applied during decoding: to the
    third and fourth resized feature levels and to the outputs of
    ``refinenet4`` and ``refinenet3``.

    Args:
        in_channels: Forwarded to ``DPTHead``.
        features: Forwarded to ``DPTHead``; also the channel count of the two
            temporal modules applied after the refinement stages.
        use_bn: Forwarded to ``DPTHead``.
        out_channels: Per-level channel counts, forwarded to ``DPTHead``;
            entries 2 and 3 size the first two temporal modules.
        use_clstoken: Forwarded to ``DPTHead``.
        num_frames: Passed to every temporal module as ``temporal_max_len``;
            must be positive.
        pe: Passed to every temporal module as ``pos_embedding_type``.
    """

    def __init__(self,
        in_channels,
        features=256,
        use_bn=False,
        # NOTE: list default kept for interface compatibility; it is never
        # mutated in this class.
        out_channels=[256, 512, 1024, 1024],
        use_clstoken=False,
        num_frames=32,
        pe='ape'
    ):
        super().__init__(in_channels, features, use_bn, out_channels, use_clstoken)

        assert num_frames > 0
        # A plain dict is enough here: the mapping is only ``**``-expanded.
        motion_module_kwargs = dict(
            num_attention_heads=8,
            num_transformer_block=1,
            num_attention_blocks=2,
            temporal_max_len=num_frames,
            zero_initialize=True,
            pos_embedding_type=pe,
        )

        self.motion_modules = nn.ModuleList([
            TemporalModule(in_channels=out_channels[2], **motion_module_kwargs),
            TemporalModule(in_channels=out_channels[3], **motion_module_kwargs),
            TemporalModule(in_channels=features, **motion_module_kwargs),
            TemporalModule(in_channels=features, **motion_module_kwargs),
        ])

    def _temporal_mix(self, index, x, batch, frames):
        """Apply ``self.motion_modules[index]`` to ``x`` of shape ``(batch * frames, C, H, W)``."""
        # (B*T, C, H, W) -> (B, T, C, H, W) -> (B, C, T, H, W)
        clip = x.unflatten(0, (batch, frames)).permute(0, 2, 1, 3, 4)
        clip = self.motion_modules[index](clip, None, None)
        # Back to the frame-flattened layout used by the 2-D layers.
        return clip.permute(0, 2, 1, 3, 4).flatten(0, 1)

    def forward(self, out_features, patch_h, patch_w, frame_length):
        """Decode frame-flattened backbone tokens into a dense prediction map.

        Args:
            out_features: Iterable of four per-level entries. ``entry[0]``
                holds the patch tokens, a 3-D tensor whose middle axis has
                ``patch_h * patch_w`` elements; ``entry[1]`` holds the class
                token and is only read when ``self.use_clstoken`` is set.
            patch_h: Number of patches along the height.
            patch_w: Number of patches along the width.
            frame_length: Number of frames per clip; the leading axis of the
                features is split into ``(leading // frame_length,
                frame_length)`` for the temporal modules.

        Returns:
            The result of ``self.scratch.output_conv2`` applied to features
            that were bilinearly resized to ``(patch_h * 14, patch_w * 14)``.
        """
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            else:
                x = x[0]

            # (batch, tokens, channels) -> (batch, channels, patch_h, patch_w)
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous()

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        # Computed once here; the original also recomputed it (unused) on
        # every loop iteration above.
        B, T = layer_1.shape[0] // frame_length, frame_length

        layer_3 = self._temporal_mix(0, layer_3, B, T)
        layer_4 = self._temporal_mix(1, layer_4, B, T)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_4 = self._temporal_mix(2, path_4, B, T)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_3 = self._temporal_mix(3, path_3, B, T)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv1(path_1)
        # NOTE(review): 14 is presumably the backbone patch size -- confirm.
        out = F.interpolate(
            out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
        )
        out = self.scratch.output_conv2(out)

        return out
|
video_depth_anything/motion_module/attention.py
ADDED
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Tuple
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
import xformers
|
21 |
+
import xformers.ops
|
22 |
+
|
23 |
+
|
24 |
+
class CrossAttention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Cast query and key to float32 before computing the attention scores.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Cast the attention scores to float32 before the softmax.
        added_kv_proj_dim (`int`, *optional*):
            If given, `encoder_hidden_states` is passed through extra key/value projections and prepended to
            the keys/values computed from `hidden_states`.
        norm_num_groups (`int`, *optional*):
            If given, a `GroupNorm` with this many groups is applied to `hidden_states` before the projections.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.upcast_efficient_attention = False

        self.scale = dim_head**-0.5

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads
        self._slice_size = None
        self._use_memory_efficient_attention_xformers = False
        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            # NOTE(review): the norm is sized with inner_dim but applied to hidden_states in forward(),
            # so it presumably requires inner_dim == query_dim -- confirm.
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        # Index 0 is the output projection, index 1 the dropout (see forward()).
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

    def reshape_heads_to_batch_dim(self, tensor):
        """Reshape `(batch, seq, heads * d)` into `(batch * heads, seq, d)`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous()
        return tensor

    def reshape_heads_to_4d(self, tensor):
        """Reshape `(batch, seq, heads * d)` into `(batch, seq, heads, d)`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """Inverse of `reshape_heads_to_batch_dim`: `(batch * heads, seq, d)` -> `(batch, seq, heads * d)`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous()
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous()
        return tensor

    def reshape_4d_to_heads(self, tensor):
        """Inverse of `reshape_heads_to_4d`: `(batch, seq, heads, d)` -> `(batch, seq, heads * d)`."""
        batch_size, seq_len, _, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, dim * self.heads).contiguous()
        return tensor

    def set_attention_slice(self, slice_size):
        """Set the slice size used by the sliced attention path (`None` disables slicing).

        Raises:
            ValueError: If `slice_size` is larger than the number of heads.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        self._slice_size = slice_size

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Attend from `hidden_states` to `encoder_hidden_states` (or to itself when that is `None`).

        Args:
            hidden_states: Tensor of shape `(batch, seq, query_dim)`.
            encoder_hidden_states: Optional context tensor. Required when `added_kv_proj_dim` was given.
            attention_mask: Optional additive mask; it is repeated once per head along dim 0.

        Returns:
            Tensor of shape `(batch, seq, query_dim)`.
        """
        batch_size, sequence_length, _ = hidden_states.shape

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            key = self.to_k(hidden_states)
            value = self.to_v(hidden_states)
            encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)

            key = self.reshape_heads_to_batch_dim(key)
            value = self.reshape_heads_to_batch_dim(value)
            encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
            encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)

            key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
            key = self.to_k(encoder_hidden_states)
            value = self.to_v(encoder_hidden_states)

            key = self.reshape_heads_to_batch_dim(key)
            value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states

    def _attention(self, query, key, value, attention_mask=None):
        """Plain softmax attention over `(batch * heads, seq, d)` tensors; returns `(batch, seq, heads * d)`."""
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        # beta=0 makes baddbmm ignore the uninitialised `input` buffer.
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)

        # cast back to the original dtype
        attention_probs = attention_probs.to(value.dtype)

        # compute attention output
        hidden_states = torch.bmm(attention_probs, value)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
        """Same result as `_attention`, computed `slice_size` head-batches at a time to save memory."""
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
        # Step through the whole batch, including a trailing partial slice. The previous
        # `range(batch // slice_size)` loop silently left those rows as zeros.
        for start_idx in range(0, batch_size_attention, slice_size):
            end_idx = min(start_idx + slice_size, batch_size_attention)

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]

            if self.upcast_attention:
                query_slice = query_slice.float()
                key_slice = key_slice.float()

            # beta=0 makes baddbmm ignore the uninitialised `input` buffer.
            attn_slice = torch.baddbmm(
                torch.empty(
                    query_slice.shape[0], query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device
                ),
                query_slice,
                key_slice.transpose(-1, -2),
                beta=0,
                alpha=self.scale,
            )

            if attention_mask is not None:
                attn_slice = attn_slice + attention_mask[start_idx:end_idx]

            if self.upcast_softmax:
                attn_slice = attn_slice.float()

            attn_slice = attn_slice.softmax(dim=-1)

            # cast back to the original dtype
            attn_slice = attn_slice.to(value.dtype)
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        """Attention through xformers; accepts 4-D `(batch, seq, heads, d)` or 3-D `(batch * heads, seq, d)` inputs."""
        if self.upcast_efficient_attention:
            org_dtype = query.dtype
            query = query.float()
            key = key.float()
            value = value.float()
            if attention_mask is not None:
                attention_mask = attention_mask.float()
        hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask)

        if self.upcast_efficient_attention:
            hidden_states = hidden_states.to(org_dtype)

        # forward() hands over heads folded into the batch axis (3-D); unpacking four
        # dimensions from that result used to raise. Pick the matching inverse reshape.
        if hidden_states.dim() == 4:
            hidden_states = self.reshape_4d_to_heads(hidden_states)
        else:
            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _memory_efficient_attention_split(self, query, key, value, attention_mask):
        """Call xformers in chunks of at most 65535 entries along dim 0 and concatenate the results."""
        batch_size = query.shape[0]
        # NOTE(review): presumably a kernel launch limit on the batch axis -- confirm.
        max_batch_size = 65535
        num_batches = (batch_size + max_batch_size - 1) // max_batch_size
        results = []
        for i in range(num_batches):
            start_idx = i * max_batch_size
            end_idx = min((i + 1) * max_batch_size, batch_size)
            query_batch = query[start_idx:end_idx]
            key_batch = key[start_idx:end_idx]
            value_batch = value[start_idx:end_idx]
            if attention_mask is not None:
                attention_mask_batch = attention_mask[start_idx:end_idx]
            else:
                attention_mask_batch = None
            result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch)
            results.append(result)
        full_result = torch.cat(results, dim=0)
        return full_result
|
288 |
+
|
289 |
+
|
290 |
+
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"geglu"` or `"geglu-approximate"`.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            # Previously an unknown name fell through and failed later with an
            # UnboundLocalError on `act_fn`; fail early with a clear message.
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; "
                "expected 'gelu', 'geglu' or 'geglu-approximate'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states):
        """Apply the projection/activation, dropout and output projection in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
|
333 |
+
|
334 |
+
|
335 |
+
class GELU(nn.Module):
    r"""
    Linear projection followed by a GELU activation.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def gelu(self, gate):
        """Apply GELU, going through float32 on `mps` devices."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            source_dtype = gate.dtype
            upcast = gate.to(dtype=torch.float32)
            return F.gelu(upcast).to(dtype=source_dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        projected = self.proj(hidden_states)
        return self.gelu(projected)
|
354 |
+
|
355 |
+
|
356 |
+
# feedforward
|
357 |
+
class GEGLU(nn.Module):
    r"""
    Gated linear unit with a GELU gate, see https://arxiv.org/abs/2002.05202.

    The input is projected to twice `dim_out`; the first half is the value and
    the second half the gate.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        """Apply GELU, going through float32 on `mps` devices."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            source_dtype = gate.dtype
            upcast = gate.to(dtype=torch.float32)
            return F.gelu(upcast).to(dtype=source_dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        return value * self.gelu(gate)
|
379 |
+
|
380 |
+
|
381 |
+
class ApproximateGELU(nn.Module):
    """
    Linear projection followed by the sigmoid approximation of GELU,
    ``x * sigmoid(1.702 * x)``.

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate
|
395 |
+
|
396 |
+
|
397 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of unit complex rotations used for rotary embeddings.

    Args:
        dim: Feature dimension; ``dim // 2`` frequencies are produced.
        end: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` where entry ``(t, k)`` is
        ``exp(1j * t * theta ** (-2k / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # Unit magnitude, phase = angle  ->  complex64
    unit = torch.ones_like(angles)
    return torch.polar(unit, angles)
|
403 |
+
|
404 |
+
|
405 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result
    keeps those two sizes on axis 1 and on the last axis, and has size 1 on
    every other axis of ``x``.
    """
    rank = x.ndim
    assert 0 <= 1 < rank
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target = [1] * rank
    target[1] = x.shape[1]
    target[rank - 1] = x.shape[-1]
    return freqs_cis.view(*target)
|
411 |
+
|
412 |
+
|
413 |
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by the complex factors in ``freqs_cis``.

    Consecutive pairs along the last axis of ``xq`` / ``xk`` are treated as
    (real, imaginary) parts, multiplied by ``freqs_cis`` and flattened back
    from axis 2 onwards. Outputs are cast to the dtype of the respective
    input.
    """

    def as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last axis: (..., d) -> (..., d / 2, 2) -> complex (..., d / 2)
        paired = t.float().reshape(*t.shape[:-1], -1, 2).contiguous()
        return torch.view_as_complex(paired)

    xq_ = as_complex(xq)
    xk_ = as_complex(xk)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    rotated_q = torch.view_as_real(xq_ * freqs_cis).flatten(2)
    rotated_k = torch.view_as_real(xk_ * freqs_cis).flatten(2)
    return rotated_q.type_as(xq), rotated_k.type_as(xk)
|
video_depth_anything/motion_module/motion_module.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff
|
2 |
+
# SPDX-License-Identifier: Apache-2.0 license
|
3 |
+
#
|
4 |
+
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
|
5 |
+
# Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme].
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis
|
11 |
+
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
import math
|
14 |
+
|
15 |
+
|
16 |
+
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return ``module``."""
    parameters = module.parameters()
    for tensor in parameters:
        # detach() shares storage with the parameter, so the in-place fill
        # updates it without being recorded by autograd.
        tensor.detach().zero_()
    return module
|
21 |
+
|
22 |
+
|
23 |
+
class TemporalModule(nn.Module):
    """Thin wrapper around a ``TemporalTransformer3DModel``.

    The per-head width is ``in_channels // num_attention_heads``. When
    ``zero_initialize`` is set, the transformer's output projection is
    zeroed at construction time.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads = 8,
        num_transformer_block = 2,
        num_attention_blocks = 2,
        norm_num_groups = 32,
        temporal_max_len = 32,
        zero_initialize = True,
        pos_embedding_type = "ape",
    ):
        super().__init__()

        head_dim = in_channels // num_attention_heads
        transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=head_dim,
            num_layers=num_transformer_block,
            num_attention_blocks=num_attention_blocks,
            norm_num_groups=norm_num_groups,
            temporal_max_len=temporal_max_len,
            pos_embedding_type=pos_embedding_type,
        )

        if zero_initialize:
            transformer.proj_out = zero_module(transformer.proj_out)

        self.temporal_transformer = transformer

    def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):
        """Run the wrapped temporal transformer and return its output."""
        return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)
|
57 |
+
|
58 |
+
|
59 |
+
class TemporalTransformer3DModel(nn.Module):
    """GroupNorm -> linear in -> temporal transformer blocks -> linear out, with a residual.

    Operates on 5-D inputs laid out as ``(b, c, f, h, w)`` and returns a
    tensor with the same layout.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        num_attention_blocks = 2,
        norm_num_groups = 32,
        temporal_max_len = 32,
        pos_embedding_type = "ape",
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    num_attention_blocks=num_attention_blocks,
                    temporal_max_len=temporal_max_len,
                    pos_embedding_type=pos_embedding_type,
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Apply the temporal transformer to a ``(b, c, f, h, w)`` tensor."""
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        frames = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channel, height, width = frames.shape
        residual = frames

        normed = self.norm(frames)
        channels = normed.shape[1]
        # (b f, c, h, w) -> (b f, h * w, c): one token per spatial location.
        tokens = normed.permute(0, 2, 3, 1).reshape(batch, height * width, channels).contiguous()
        tokens = self.proj_in(tokens)

        # Transformer Blocks
        for block in self.transformer_blocks:
            tokens = block(
                tokens,
                encoder_hidden_states=encoder_hidden_states,
                video_length=video_length,
                attention_mask=attention_mask,
            )

        # output
        tokens = self.proj_out(tokens)
        restored = tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()

        output = restored + residual
        return rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
118 |
+
|
119 |
+
|
120 |
+
class TemporalTransformerBlock(nn.Module):
    """Stack of pre-norm temporal attention layers followed by a pre-norm GEGLU feed-forward.

    Every attention layer and the feed-forward are wrapped in a residual
    connection.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        num_attention_blocks = 2,
        temporal_max_len = 32,
        pos_embedding_type = "ape",
    ):
        super().__init__()

        attention_layers = []
        for _ in range(num_attention_blocks):
            attention_layers.append(
                TemporalAttention(
                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    temporal_max_len=temporal_max_len,
                    pos_embedding_type=pos_embedding_type,
                )
            )
        self.attention_blocks = nn.ModuleList(attention_layers)

        norm_layers = []
        for _ in range(num_attention_blocks):
            norm_layers.append(nn.LayerNorm(dim))
        self.norms = nn.ModuleList(norm_layers)

        self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu")
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        """Run all attention layers and the feed-forward, each with a residual connection."""
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            attended = attention_block(
                norm(hidden_states),
                encoder_hidden_states=encoder_hidden_states,
                video_length=video_length,
                attention_mask=attention_mask,
            )
            hidden_states = attended + hidden_states

        ff_out = self.ff(self.ff_norm(hidden_states))
        return ff_out + hidden_states
|
169 |
+
|
170 |
+
|
171 |
+
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added along dimension 1 of the input.

    The table is stored in the non-trainable buffer ``pe`` with shape
    ``(1, max_len, d_model)``: even channels hold sines, odd channels hold
    cosines of ``position * 10000 ** (-2i / d_model)``.
    """

    def __init__(
        self,
        d_model,
        dropout = 0.,
        max_len = 32
    ):
        """Precompute the encoding table.

        Args:
            d_model: Channel width of the inputs (even or odd).
            dropout: Dropout probability applied after adding the encoding.
            max_len: Number of positions precomputed; longer sequences are
                not supported by ``forward``.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) channel than
        # even (sine) channel, so drop the surplus frequency column instead
        # of failing with a shape mismatch. For even d_model the slice is a
        # no-op and the table is unchanged.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` encodings to ``x`` and apply dropout.

        The table is cast to ``x.dtype`` before the addition.
        """
        x = x + self.pe[:, :x.size(1)].to(x.dtype)
        return self.dropout(x)
|
190 |
+
|
191 |
+
class TemporalAttention(CrossAttention):
    """Attention over the frame axis, built on top of ``CrossAttention``.

    ``forward`` regroups tokens so that the sequence dimension is the frame
    index, optionally injects positional information ("ape" adds a
    ``PositionalEncoding`` to the input, "rope" rotates query/key), runs the
    inherited attention kernels and restores the original token grouping.
    """

    def __init__(
        self,
        temporal_max_len = 32,
        pos_embedding_type = "ape",
        *args, **kwargs
    ):
        """Configure the positional scheme and initialise the base attention.

        Args:
            temporal_max_len: Maximum number of frames the positional
                encoding / rotary table is precomputed for.
            pos_embedding_type: ``"ape"`` or ``"rope"``.
            *args, **kwargs: Forwarded to ``CrossAttention.__init__``.
                ``query_dim`` must be supplied as a keyword because it is
                read from ``kwargs`` below.

        Raises:
            NotImplementedError: For any other ``pos_embedding_type``.
        """
        super().__init__(*args, **kwargs)

        self.pos_embedding_type = pos_embedding_type
        # Always request the memory-efficient path; forward() falls back to
        # plain attention when the per-head width is not a multiple of 8.
        self._use_memory_efficient_attention_xformers = True

        # Exactly one of these is populated, depending on the scheme.
        self.pos_encoder = None
        self.freqs_cis = None
        if self.pos_embedding_type == "ape":
            self.pos_encoder = PositionalEncoding(
                kwargs["query_dim"],
                dropout=0.,
                max_len=temporal_max_len
            )

        elif self.pos_embedding_type == "rope":
            # NOTE(review): stored as a plain attribute (not a registered
            # buffer), hence the explicit .to(query.device) in forward().
            self.freqs_cis = precompute_freqs_cis(
                kwargs["query_dim"],
                temporal_max_len
            )

        else:
            raise NotImplementedError

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        """Attend across frames for every spatial token.

        Args:
            hidden_states: 3-D tensor grouped as ``(b f) d c`` (per the
                rearrange pattern below), where ``f == video_length``.
            encoder_hidden_states: Optional ``b n c`` context; repeated once
                per spatial token. When ``None`` the layer self-attends.
            attention_mask: Optional mask handed to the attention kernel.
            video_length: Number of frames ``f`` folded into dimension 0.

        Returns:
            Tensor grouped as ``(b f) d c``, like the input.

        Raises:
            NotImplementedError: If ``added_kv_proj_dim`` is set, or if
                sliced attention would be required.
        """
        # d = number of tokens per frame; kept to undo the regrouping later.
        d = hidden_states.shape[1]
        # Make the frame axis the sequence axis: one sequence per (batch, token).
        hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

        if self.pos_encoder is not None:
            hidden_states = self.pos_encoder(hidden_states)

        encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states

        if self.group_norm is not None:
            # group_norm is applied with channels in dimension 1, hence the transposes.
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        # Self-attention when no external context is given.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        if self.freqs_cis is not None:
            # Rotary variant: rotate query/key using the first seq_len table rows.
            seq_len = query.shape[1]
            freqs_cis = self.freqs_cis[:seq_len].to(query.device)
            query, key = apply_rotary_emb(query, key, freqs_cis)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                # NOTE(review): this appends target_length zeros on the right
                # rather than padding *to* target_length — confirm intended.
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                # NOTE(review): the mask is expanded per head only when it was
                # padded above — confirm this nesting matches the upstream file.
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)


        use_memory_efficient = self._use_memory_efficient_attention_xformers
        if use_memory_efficient and (dim // self.heads) % 8 != 0:
            # print('Warning: the dim {} cannot be divided by 8. Fall into normal attention'.format(dim // self.heads))
            use_memory_efficient = False

        # attention, what we cannot get enough of
        if use_memory_efficient:
            query = self.reshape_heads_to_4d(query)
            key = self.reshape_heads_to_4d(key)
            value = self.reshape_heads_to_4d(value)

            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            query = self.reshape_heads_to_batch_dim(query)
            key = self.reshape_heads_to_batch_dim(key)
            value = self.reshape_heads_to_batch_dim(value)

            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                # Sliced attention is not supported by this subclass.
                raise NotImplementedError
                # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        # Undo the regrouping done at the top of the method.
        hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
|
video_depth_anything/util/blocks.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
5 |
+
scratch = nn.Module()
|
6 |
+
|
7 |
+
out_shape1 = out_shape
|
8 |
+
out_shape2 = out_shape
|
9 |
+
out_shape3 = out_shape
|
10 |
+
if len(in_shape) >= 4:
|
11 |
+
out_shape4 = out_shape
|
12 |
+
|
13 |
+
if expand:
|
14 |
+
out_shape1 = out_shape
|
15 |
+
out_shape2 = out_shape * 2
|
16 |
+
out_shape3 = out_shape * 4
|
17 |
+
if len(in_shape) >= 4:
|
18 |
+
out_shape4 = out_shape * 8
|
19 |
+
|
20 |
+
scratch.layer1_rn = nn.Conv2d(
|
21 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
22 |
+
)
|
23 |
+
scratch.layer2_rn = nn.Conv2d(
|
24 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
25 |
+
)
|
26 |
+
scratch.layer3_rn = nn.Conv2d(
|
27 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
28 |
+
)
|
29 |
+
if len(in_shape) >= 4:
|
30 |
+
scratch.layer4_rn = nn.Conv2d(
|
31 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
32 |
+
)
|
33 |
+
|
34 |
+
return scratch
|
35 |
+
|
36 |
+
|
37 |
+
class ResidualConvUnit(nn.Module):
    """Residual unit: two (activation -> 3x3 conv [-> batch norm]) stages plus a skip add."""

    def __init__(self, features, activation, bn):
        """Build the two convolutions and, when ``bn is True``, their batch norms.

        Args:
            features (int): number of input and output channels
            activation: callable applied before each convolution
            bn: batch norm is used only when this is exactly ``True``
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_options)
        self.conv2 = nn.Conv2d(features, features, **conv_options)

        if self.bn is True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        # Addition goes through FloatFunctional, as in the original module.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Return ``x`` plus the output of the two conv stages.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = x
        for conv, norm_name in ((self.conv1, "bn1"), (self.conv2, "bn2")):
            out = conv(self.activation(out))
            if self.bn is True:
                out = getattr(self, norm_name)(out)

        # groups is fixed to 1 in __init__, so this merge step never runs as built.
        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
|
92 |
+
|
93 |
+
|
94 |
+
class FeatureFusionBlock(nn.Module):
    """Fuse up to two feature maps, resize bilinearly and project with a 1x1 conv."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
    ):
        """Build the residual units and the output projection.

        Args:
            features (int): number of input channels
            activation: activation handed to both residual units
            deconv: stored on the instance; not used in this class
            bn: forwarded to both residual units
            expand: when exactly ``True`` the projection halves the channels
            align_corners: ``align_corners`` used for the bilinear resize
            size: default output size used when ``forward`` gets none
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand

        out_features = features // 2 if self.expand is True else features
        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1
        )

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Fuse the inputs, resize and project.

        With two inputs, the second goes through ``resConfUnit1`` and is
        added to the first. The result goes through ``resConfUnit2``, is
        resized to ``size`` (falling back to ``self.size``, then to a 2x
        upscale) and finally through ``out_conv``.

        Returns:
            tensor: output
        """
        fused = xs[0]
        if len(xs) == 2:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        # An explicit call-time size wins over the size fixed at construction.
        target = size if size is not None else self.size
        resize_args = {"scale_factor": 2} if target is None else {"size": target}

        fused = nn.functional.interpolate(
            fused.contiguous(), **resize_args, mode="bilinear", align_corners=self.align_corners
        )

        return self.out_conv(fused)
|
video_depth_anything/util/transform.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
|
5 |
+
class Resize(object):
    """Resize a sample's image (and optionally depth/mask) to a target (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Store the resize configuration.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height are constrained to be a multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
            image_interpolation_method: OpenCV interpolation flag used for the image.
        """
        self.__width, self.__height = width, height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Snap ``x`` to a multiple of ``ensure_multiple_of`` within the given bounds.

        Rounds to nearest first; rounds down if that exceeds ``max_val``;
        rounds up if the result falls below ``min_val``.
        """
        step = self.__multiple_of

        def snap(rounding):
            return (rounding(x / step) * step).astype(int)

        snapped = snap(np.round)

        if max_val is not None and snapped > max_val:
            snapped = snap(np.floor)

        if snapped < min_val:
            snapped = snap(np.ceil)

        return snapped

    def get_size(self, width, height):
        """Compute the output ``(width, height)`` for an input of the given size.

        Raises:
            ValueError: If the configured ``resize_method`` is unknown.
        """
        # Per-axis scale factors towards the requested size.
        scale_height = self.__height / height
        scale_width = self.__width / width

        method = self.__resize_method
        if method not in ("lower_bound", "upper_bound", "minimal"):
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__keep_aspect_ratio:
            # Pick one axis whose scale is applied to both axes.
            if method == "lower_bound":
                fit_width = scale_width > scale_height
            elif method == "upper_bound":
                fit_width = scale_width < scale_height
            else:
                # "minimal": the scale closest to 1 wins.
                fit_width = abs(1 - scale_width) < abs(1 - scale_height)

            if fit_width:
                scale_height = scale_width
            else:
                scale_width = scale_height

        # The requested size acts as a floor, a ceiling, or is unconstrained.
        if method == "lower_bound":
            height_bounds = {"min_val": self.__height}
            width_bounds = {"min_val": self.__width}
        elif method == "upper_bound":
            height_bounds = {"max_val": self.__height}
            width_bounds = {"max_val": self.__width}
        else:
            height_bounds = {}
            width_bounds = {}

        new_height = self.constrain_to_multiple_of(scale_height * height, **height_bounds)
        new_width = self.constrain_to_multiple_of(scale_width * width, **width_bounds)

        return (new_width, new_height)

    def __call__(self, sample):
        """Resize ``sample["image"]`` in place and, if enabled, ``depth`` and ``mask``.

        Depth and mask are resized with nearest-neighbour interpolation; the
        mask is cast to float32 first. Returns the same ``sample`` mapping.
        """
        out_w, out_h = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
        dsize = (out_w, out_h)

        # resize sample
        sample["image"] = cv2.resize(sample["image"], dsize, interpolation=self.__image_interpolation_method)

        if self.__resize_target and "depth" in sample:
            sample["depth"] = cv2.resize(sample["depth"], dsize, interpolation=cv2.INTER_NEAREST)

        if self.__resize_target and "mask" in sample:
            sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), dsize, interpolation=cv2.INTER_NEAREST)

        return sample
|
123 |
+
|
124 |
+
|
125 |
+
class NormalizeImage(object):
    """Normalize the image of a sample with a given mean and std."""

    def __init__(self, mean, std):
        """Store the mean and std subtracted from / divided into the image."""
        self.__mean, self.__std = mean, std

    def __call__(self, sample):
        """Replace ``sample["image"]`` with ``(image - mean) / std`` and return the sample."""
        centered = sample["image"] - self.__mean
        sample["image"] = centered / self.__std
        return sample
|
137 |
+
|
138 |
+
|
139 |
+
class PrepareForNet(object):
    """Convert a sample's arrays into contiguous float32 network inputs."""

    def __init__(self):
        pass

    def __call__(self, sample):
        """Reorder the image to channels-first and cast image/depth/mask to float32.

        The image axes are permuted with ``(2, 0, 1)``; ``depth`` and
        ``mask`` are converted only when present. Returns the same mapping.
        """
        channels_first = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(channels_first).astype(np.float32)

        # depth and mask get the same treatment: cast, then make contiguous.
        for key in ("depth", "mask"):
            if key in sample:
                as_float = sample[key].astype(np.float32)
                sample[key] = np.ascontiguousarray(as_float)

        return sample
|
video_depth_anything/video_depth.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.nn as nn
|
17 |
+
from torchvision.transforms import Compose
|
18 |
+
import cv2
|
19 |
+
from tqdm import tqdm
|
20 |
+
import numpy as np
|
21 |
+
import gc
|
22 |
+
|
23 |
+
from .dinov2 import DINOv2
|
24 |
+
from .dpt_temporal import DPTHeadTemporal
|
25 |
+
from .util.transform import Resize, NormalizeImage, PrepareForNet
|
26 |
+
|
27 |
+
from utils.util import compute_scale_and_shift, get_interpolate_frames
|
28 |
+
|
29 |
+
# infer settings, do not change
|
30 |
+
INFER_LEN = 32
|
31 |
+
OVERLAP = 10
|
32 |
+
KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
|
33 |
+
INTERP_LEN = 8
|
34 |
+
|
35 |
+
class VideoDepthAnything(nn.Module):
    """Video depth model: a DINOv2 encoder followed by a temporal DPT head.

    ``forward`` predicts depth for one clip tensor; ``infer_video_depth``
    runs it over a whole video in overlapping windows of ``INFER_LEN``
    frames and stitches the per-window predictions together.
    """

    def __init__(
        self,
        encoder='vitl',
        features=256,
        # NOTE: list default is shared between calls; here it is only
        # forwarded to DPTHeadTemporal.
        out_channels=[256, 512, 1024, 1024],
        use_bn=False,
        use_clstoken=False,
        num_frames=32,
        pe='ape'
    ):
        """Build the encoder and the temporal head.

        Args:
            encoder: Key into ``intermediate_layer_idx`` ('vits' or 'vitl');
                also passed to ``DINOv2`` as ``model_name``.
            features, out_channels, use_bn, use_clstoken, num_frames, pe:
                Forwarded to ``DPTHeadTemporal``.
        """
        super(VideoDepthAnything, self).__init__()

        # Encoder layer indices whose outputs are fed to the head.
        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23]
        }

        self.encoder = encoder
        self.pretrained = DINOv2(model_name=encoder)

        self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)

    def forward(self, x):
        """Predict non-negative depth for a clip.

        Args:
            x: 5-D tensor unpacked as ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``[B, T, H, W]``.
        """
        B, T, C, H, W = x.shape
        # Patch grid size for a patch edge of 14 pixels.
        patch_h, patch_w = H // 14, W // 14
        # Batch and time are merged so the encoder sees B*T independent images.
        features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
        depth = self.head(features, patch_h, patch_w, T)
        depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
        depth = F.relu(depth)
        return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W]

    def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
        """Predict a depth map for every frame of a video.

        The video is processed in windows of ``INFER_LEN`` frames that start
        ``INFER_LEN - OVERLAP`` frames apart. From the second window on, the
        first ``OVERLAP`` input frames are replaced by the ``KEYFRAMES`` of
        the previous window's input, and the window's depths are brought
        onto the scale of the already-stitched result before being appended.

        Args:
            frames: Array indexed by frame along axis 0; each frame's first
                two dimensions are used as the output size.
                (presumably uint8 HxWx3 images, since frames are divided by
                255 — TODO confirm against caller)
            target_fps: Returned unchanged alongside the depths.
            input_size: Target width/height passed to ``Resize``.
            device: Device the input tensors are moved to.

        Returns:
            Tuple ``(depth_list, target_fps)`` where ``depth_list`` holds one
            numpy depth map per original frame, at the original frame size.
        """
        transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])

        frame_size = frames[0].shape[:2]
        frame_list = [frames[i] for i in range(frames.shape[0])]
        # Distance between the starts of consecutive windows.
        frame_step = INFER_LEN - OVERLAP
        org_video_len = len(frame_list)
        # Pad with copies of the last frame so that the last window (which
        # starts at a multiple of frame_step) still has INFER_LEN frames.
        append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
        frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len

        # Raw per-window predictions: INFER_LEN depth maps are added per window.
        depth_list = []
        pre_input = None
        for frame_id in tqdm(range(0, org_video_len, frame_step)):
            cur_list = []
            for i in range(INFER_LEN):
                # Each frame becomes a [1, 1, C, H, W] tensor.
                cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))
            cur_input = torch.cat(cur_list, dim=1).to(device)
            if pre_input is not None:
                # Overwrite the first OVERLAP frames with the previous
                # window's keyframes.
                cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]

            with torch.no_grad():
                depth = self.forward(cur_input) # depth shape: [1, T, H, W]

            # Back to the original frame resolution.
            depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=frame_size, mode='bilinear', align_corners=True)
            depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]

            pre_input = cur_input

        del frame_list
        gc.collect()

        # Stitch the windows into one sequence with a consistent scale.
        depth_list_aligned = []
        ref_align = []
        # Of the OVERLAP shared frames, the first align_len are used to fit
        # scale/shift and the remaining INTERP_LEN are blended.
        align_len = OVERLAP - INTERP_LEN
        kf_align_list = KEYFRAMES[:align_len]

        for frame_id in range(0, len(depth_list), INFER_LEN):
            if len(depth_list_aligned) == 0:
                # First window: taken as-is; remember its keyframe depths as reference.
                depth_list_aligned += depth_list[:INFER_LEN]
                for kf_id in kf_align_list:
                    ref_align.append(depth_list[frame_id+kf_id])
            else:
                # Depths of this window's leading frames (the re-predicted keyframes).
                curr_align = []
                for i in range(len(kf_align_list)):
                    curr_align.append(depth_list[frame_id+i])
                # All-True mask: every pixel takes part in the fit.
                # (presumably a least-squares scale/shift of curr onto ref —
                # verify in utils.util.compute_scale_and_shift)
                scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
                                                       np.concatenate(ref_align),
                                                       np.concatenate(np.ones_like(ref_align)==1))

                # Replace the last INTERP_LEN stitched frames using the
                # rescaled (clamped at 0) overlap frames of this window.
                pre_depth_list = depth_list_aligned[-INTERP_LEN:]
                post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP]
                for i in range(len(post_depth_list)):
                    post_depth_list[i] = post_depth_list[i] * scale + shift
                    post_depth_list[i][post_depth_list[i]<0] = 0
                depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)

                # Append the genuinely new frames, rescaled and clamped at 0.
                for i in range(OVERLAP, INFER_LEN):
                    new_depth = depth_list[frame_id+i] * scale + shift
                    new_depth[new_depth<0] = 0
                    depth_list_aligned.append(new_depth)

                # Keep the very first reference and refresh the others from
                # this window's rescaled keyframes.
                ref_align = ref_align[:1]
                for kf_id in kf_align_list[1:]:
                    new_depth = depth_list[frame_id+kf_id] * scale + shift
                    new_depth[new_depth<0] = 0
                    ref_align.append(new_depth)

        depth_list = depth_list_aligned

        # Drop the depths that belong to the padding frames.
        return depth_list[:org_video_len], target_fps
|
149 |
+
|