ascarlettvfx commited on
Commit
a49153f
1 Parent(s): 5228808

Create marigold_depth_estimation.py

Browse files
Files changed (1) hide show
  1. marigold_depth_estimation.py +79 -0
marigold_depth_estimation.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from diffusers import DiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDIMScheduler
5
+ from diffusers.utils import BaseOutput
6
+
7
+ class MarigoldDepthOutput(BaseOutput):
8
+ depth_np: np.ndarray
9
+ depth_image: Image.Image
10
+
11
+ class MarigoldPipeline(DiffusionPipeline):
12
+ def __init__(self, unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: DDIMScheduler):
13
+ super().__init__()
14
+ self.unet = unet
15
+ self.vae = vae
16
+ self.scheduler = scheduler
17
+
18
+ @torch.no_grad()
19
+ def __call__(self, input_image: Image, denoising_steps: int = 10, save_path: str = None) -> MarigoldDepthOutput:
20
+ device = self.device
21
+
22
+ # Image preprocessing
23
+ input_image = input_image.convert("RGB")
24
+ image = np.asarray(input_image)
25
+ rgb = np.transpose(image, (2, 0, 1))
26
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0
27
+ rgb_norm = torch.from_numpy(rgb_norm).to(device)
28
+
29
+ # Encode image
30
+ rgb_latent = self._encode_rgb(rgb_norm)
31
+
32
+ # Initial depth map (noise)
33
+ depth_latent = torch.randn(rgb_latent.shape, device=device)
34
+
35
+ # Denoising loop
36
+ timesteps = self.scheduler.timesteps
37
+ for t in timesteps:
38
+ unet_input = torch.cat([rgb_latent, depth_latent], dim=1)
39
+ noise_pred = self.unet(unet_input, t).sample
40
+ depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
41
+
42
+ # Decode depth map
43
+ depth = self._decode_depth(depth_latent)
44
+
45
+ # Scale to [0, 1] and convert to numpy
46
+ depth = (depth + 1.0) / 2.0
47
+ depth_np = depth.cpu().numpy().astype(np.float32)
48
+ depth_image = (depth_np * 255).astype(np.uint8)
49
+ depth_image = Image.fromarray(depth_image[0], 'L') # 'L' mode for grayscale image
50
+
51
+ # Save the depth map image if a path is provided
52
+ if save_path:
53
+ depth_image.save(save_path)
54
+
55
+ return MarigoldDepthOutput(depth_np=depth_np, depth_image=depth_image)
56
+
57
+ def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
58
+ h = self.vae.encoder(rgb_in)
59
+ moments = self.vae.quant_conv(h)
60
+ mean, _ = torch.chunk(moments, 2, dim=1)
61
+ rgb_latent = mean * 0.18215
62
+ return rgb_latent
63
+
64
+ def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
65
+ z = self.vae.post_quant_conv(depth_latent)
66
+ stacked = self.vae.decoder(z)
67
+ depth_mean = stacked.mean(dim=1, keepdim=True)
68
+ return depth_mean
69
+
70
+ # Instantiate the model components and the pipeline
71
+ unet_model = UNet2DConditionModel()
72
+ vae_model = AutoencoderKL()
73
+ scheduler = DDIMScheduler()
74
+ pipeline = MarigoldPipeline(unet=unet_model, vae=vae_model, scheduler=scheduler)
75
+
76
+ # Load an image and predict the depth map
77
+ input_image = Image.open('path_to_your_image.jpg')
78
+ output_path = 'path_to_save_image.jpg' # Specify the path where you want to save the depth image
79
+ output = pipeline(input_image, denoising_steps=10, save_path=output_path)