depthanyvideo committed on
Commit 4be2365
1 Parent(s): e9f3e75
Files changed (1)
  1. app.py +187 -8
app.py CHANGED
@@ -1,15 +1,194 @@
  import gradio as gr
  import spaces
  import torch

- zero = torch.Tensor([0]).cuda()
- print(zero.device) # <-- 'cpu' 🤔

- @spaces.GPU
- def greet(n):
-     print(zero.device) # <-- 'cuda:0' 🤗
-     return f"Hello {zero + n} Tensor"

- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
- demo.launch()
  import gradio as gr
+ import logging
+ import os
+ import random
+ import tempfile
+ import time
  import spaces
+ from easydict import EasyDict
+ import numpy as np
  import torch
+ from dav.pipelines import DAVPipeline
+ from dav.models import UNetSpatioTemporalRopeConditionModel
+ from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler
+ from dav.utils import img_utils

+
+ def seed_all(seed: int = 0):
+     """
+     Set random seeds for reproducibility.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+
+ # Initialize logging
+ logging.basicConfig(level=logging.INFO)
+
+
+ # Load models once to avoid reloading on every inference
+ def load_models(model_base, device):
+     vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
+     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+         model_base, subfolder="scheduler"
+     )
+     unet = UNetSpatioTemporalRopeConditionModel.from_pretrained(
+         model_base, subfolder="unet"
+     )
+     unet_interp = UNetSpatioTemporalRopeConditionModel.from_pretrained(
+         model_base, subfolder="unet_interp"
+     )
+     pipe = DAVPipeline(
+         vae=vae,
+         unet=unet,
+         unet_interp=unet_interp,
+         scheduler=scheduler,
+     )
+     pipe = pipe.to(device)
+     return pipe
+
+
+ # Load models at startup
+ MODEL_BASE = "hhyangcs/depth-any-video"
+ DEVICE_TYPE = "cuda"
+ DEVICE = torch.device(DEVICE_TYPE)
+ pipe = load_models(MODEL_BASE, DEVICE)
+
+
+ @spaces.GPU(duration=140)
+ def depth_any_video(
+     file,
+     denoise_steps=3,
+     num_frames=32,
+     decode_chunk_size=16,
+     num_interp_frames=16,
+     num_overlap_frames=6,
+     max_resolution=1024,
+ ):
+     """
+     Perform depth estimation on the uploaded video/image.
+     """
+     with tempfile.TemporaryDirectory() as tmp_dir:
+         # Save the uploaded file
+         input_path = os.path.join(tmp_dir, file.name)
+         with open(input_path, "wb") as f:
+             f.write(file.read())
+
+         # Set up output directory
+         output_dir = os.path.join(tmp_dir, "output")
+         os.makedirs(output_dir, exist_ok=True)
+
+         # Prepare configuration
+         cfg = EasyDict(
+             {
+                 "model_base": MODEL_BASE,
+                 "data_path": input_path,
+                 "output_dir": output_dir,
+                 "denoise_steps": denoise_steps,
+                 "num_frames": num_frames,
+                 "decode_chunk_size": decode_chunk_size,
+                 "num_interp_frames": num_interp_frames,
+                 "num_overlap_frames": num_overlap_frames,
+                 "max_resolution": max_resolution,
+                 "seed": 666,
+             }
+         )
+
+         seed_all(cfg.seed)
+
+         file_name = os.path.splitext(os.path.basename(cfg.data_path))[0]
+         is_video = cfg.data_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))
+
+         if is_video:
+             num_interp_frames = cfg.num_interp_frames
+             num_overlap_frames = cfg.num_overlap_frames
+             num_frames = cfg.num_frames
+             assert num_frames % 2 == 0, "num_frames should be even."
+             assert (
+                 2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
+             ), "Invalid frame overlap."
+             max_frames = (num_interp_frames + 2 - num_overlap_frames) * (
+                 num_frames // 2
+             )
+             image, fps = img_utils.read_video(cfg.data_path, max_frames=max_frames)
+         else:
+             image = img_utils.read_image(cfg.data_path)
+
+         image = img_utils.imresize_max(image, cfg.max_resolution)
+         image = img_utils.imcrop_multi(image)
+         image_tensor = np.ascontiguousarray(
+             [_img.transpose(2, 0, 1) / 255.0 for _img in image]
+         )
+         image_tensor = torch.from_numpy(image_tensor).to(DEVICE)
+
+         with torch.no_grad(), torch.autocast(
+             device_type=DEVICE_TYPE, dtype=torch.float16
+         ):
+             pipe_out = pipe(
+                 image_tensor,
+                 num_frames=cfg.num_frames,
+                 num_overlap_frames=cfg.num_overlap_frames,
+                 num_interp_frames=cfg.num_interp_frames,
+                 decode_chunk_size=cfg.decode_chunk_size,
+                 num_inference_steps=cfg.denoise_steps,
+             )
+
+         disparity = pipe_out.disparity
+         disparity_colored = pipe_out.disparity_colored
+         image = pipe_out.image
+         # (N, H, 2 * W, 3)
+         merged = np.concatenate(
+             [
+                 image,
+                 disparity_colored,
+             ],
+             axis=2,
+         )
+
+         if is_video:
+             output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.mp4")
+             img_utils.write_video(
+                 output_path,
+                 merged,
+                 fps,
+             )
+             return output_path
+         else:
+             output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.png")
+             img_utils.write_image(
+                 output_path,
+                 merged[0],
+             )
+             return output_path
+
+
+ # Define Gradio interface
+ title = "Depth Any Video with Scalable Synthetic Data"
+ description = """
+ Upload a video or image to perform depth estimation using the Depth Any Video model.
+ Adjust the parameters as needed to control the inference process.
+ """
+
+ iface = gr.Interface(
+     fn=depth_any_video,
+     inputs=[
+         gr.File(label="Upload Video/Image"),
+         gr.Slider(1, 10, step=1, value=3, label="Denoise Steps"),
+         gr.Slider(16, 64, step=1, value=32, label="Number of Frames"),
+         gr.Slider(8, 32, step=1, value=16, label="Decode Chunk Size"),
+         gr.Slider(8, 32, step=1, value=16, label="Number of Interpolation Frames"),
+         gr.Slider(2, 10, step=1, value=6, label="Number of Overlap Frames"),
+         gr.Slider(512, 2048, step=32, value=1024, label="Maximum Resolution"),
+     ],
+     outputs=gr.Video(label="Depth Enhanced Video/Image"),
+     title=title,
+     description=description,
+     examples=[["demos/arch_2.jpg"], ["demos/wooly_mammoth.mp4"]],
+     allow_flagging="never",
+     analytics_enabled=False,
+ )
+
+ if __name__ == "__main__":
+     iface.launch(share=True)
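
For reviewers who want to exercise the new interface end to end, below is a minimal sketch (not part of the commit) that drives the app with the Gradio Python client. The local URL, the api_name, and the example file path are assumptions; it also assumes a recent gradio_client that provides handle_file. Adjust these to the actual deployment of the Space.

# sketch_client.py -- hypothetical smoke test for the new Gradio interface
from gradio_client import Client, handle_file

# Assumption: app.py is running locally on the default Gradio port.
client = Client("http://127.0.0.1:7860")

# Arguments follow the input order defined in gr.Interface above.
result = client.predict(
    handle_file("demos/wooly_mammoth.mp4"),  # Upload Video/Image
    3,     # Denoise Steps
    32,    # Number of Frames
    16,    # Decode Chunk Size
    16,    # Number of Interpolation Frames
    6,     # Number of Overlap Frames
    1024,  # Maximum Resolution
    api_name="/predict",  # default endpoint name for a single gr.Interface
)
print(result)  # path to the depth-enhanced output returned by depth_any_video

The parameter values mirror the slider defaults in the commit, so a successful call should reproduce the default in-browser behavior.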