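# Minimal Marigold-style monocular depth estimation pipeline built on diffusers:
# the input RGB image is encoded into the Stable Diffusion latent space, a depth
# latent is denoised from Gaussian noise conditioned on it, and the result is
# decoded into a single-channel depth map.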
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from PIL import Image
from diffusers import DiffusionPipeline, UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from diffusers.utils import BaseOutput

@dataclass
class MarigoldDepthOutput(BaseOutput):
    depth_np: np.ndarray      # predicted depth as float32 values in [0, 1]
    depth_image: Image.Image  # the same prediction as an 8-bit grayscale image

class MarigoldPipeline(DiffusionPipeline):
    def __init__(self, unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: DDIMScheduler):
        super().__init__()
        # register_modules (rather than plain attribute assignment) lets
        # DiffusionPipeline handle device placement, saving, and loading
        self.register_modules(unet=unet, vae=vae, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, input_image: Image.Image, denoising_steps: int = 10, save_path: Optional[str] = None) -> MarigoldDepthOutput:
        device = self.device
        # Image preprocessing: HWC uint8 in [0, 255] -> NCHW float32 in [-1, 1]
        input_image = input_image.convert("RGB")
        image = np.asarray(input_image)
        rgb = np.transpose(image, (2, 0, 1))
        rgb_norm = rgb / 255.0 * 2.0 - 1.0
        rgb_norm = torch.from_numpy(rgb_norm).float().unsqueeze(0).to(device)
        # Encode image into the VAE latent space
        rgb_latent = self._encode_rgb(rgb_norm)
        # Initial depth latent: pure Gaussian noise, refined by the denoising loop
        depth_latent = torch.randn(rgb_latent.shape, device=device)
        # Denoising loop (set_timesteps must run before reading scheduler.timesteps)
        self.scheduler.set_timesteps(denoising_steps, device=device)
        timesteps = self.scheduler.timesteps
        # The released Marigold conditions the UNet on a CLIP embedding of the
        # empty prompt; a zero tensor of matching shape stands in for it here.
        empty_text_embed = torch.zeros(
            (rgb_latent.shape[0], 2, self.unet.config.cross_attention_dim), device=device
        )
        for t in timesteps:
            unet_input = torch.cat([rgb_latent, depth_latent], dim=1)  # concat on channels
            noise_pred = self.unet(unet_input, t, encoder_hidden_states=empty_text_embed).sample
            depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
        # Decode depth map
        depth = self._decode_depth(depth_latent)
        # Scale from [-1, 1] to [0, 1] and convert to numpy
        depth = ((depth + 1.0) / 2.0).clamp(0.0, 1.0)
        depth_np = depth.squeeze().cpu().numpy().astype(np.float32)  # (H, W)
        depth_image = Image.fromarray((depth_np * 255).astype(np.uint8), mode="L")  # grayscale
        # Save the depth map image if a path is provided
        if save_path:
            depth_image.save(save_path)
        return MarigoldDepthOutput(depth_np=depth_np, depth_image=depth_image)

    def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        # Encode with the VAE and keep only the mean of the latent distribution
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, _ = torch.chunk(moments, 2, dim=1)
        # 0.18215 is the Stable Diffusion latent scale factor
        rgb_latent = mean * 0.18215
        return rgb_latent

    def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        # Undo the latent scaling applied at encode time before decoding
        z = self.vae.post_quant_conv(depth_latent / 0.18215)
        stacked = self.vae.decoder(z)
        # The VAE decodes three channels; average them into one depth channel
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean

# Instantiate the model components and the pipeline. Freshly constructed
# modules have random weights (and the default UNet takes 4 input channels,
# not the 8 produced by the RGB+depth concatenation), so the components are
# loaded from a pretrained Marigold checkpoint; the repo id below is assumed.
checkpoint = "prs-eth/marigold-v1-0"
unet_model = UNet2DConditionModel.from_pretrained(checkpoint, subfolder="unet")
vae_model = AutoencoderKL.from_pretrained(checkpoint, subfolder="vae")
scheduler = DDIMScheduler.from_pretrained(checkpoint, subfolder="scheduler")
pipeline = MarigoldPipeline(unet=unet_model, vae=vae_model, scheduler=scheduler)
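# Move every registered module to the GPU in one call when one is available
pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")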
# Load an image and predict the depth map
input_image = Image.open('path_to_your_image.jpg')
output_path = 'path_to_save_image.png'  # where the grayscale depth image is written
output = pipeline(input_image, denoising_steps=10, save_path=output_path)
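# A possible follow-up (sketch, assuming matplotlib is installed): render the
# raw float depth values with a color map instead of plain grayscale.
import matplotlib.pyplot as plt
plt.imshow(output.depth_np, cmap="Spectral")
plt.axis("off")
plt.savefig("depth_colored.png", bbox_inches="tight")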