File size: 3,955 Bytes
bd86ed9
 
 
 
 
 
 
 
3fb6608
bd86ed9
 
 
 
3fbdaa2
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
c19a5d3
 
 
 
 
 
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fb6608
 
 
 
3034d2d
f92af9d
3fb6608
 
bd2c8ae
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import spaces
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Normalize
import tempfile
from gradio_imageslider import ImageSlider

from iebins.networks.NewCRFDepth import NewCRFDepth
from iebins.util.transfrom import Resize, NormalizeImage, PrepareForNet

css = """
#img-display-container {
    max-height: 100vh;
    }
#img-display-input {
    max-height: 80vh;
    }
#img-display-output {
    max-height: 80vh;
    }
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = NewCRFDepth(version='large07', inv_depth=False,
                    max_depth=10).to(DEVICE).eval()
model = torch.nn.DataParallel(model)
checkpoint = torch.load('checkpoints/nyu_L.pth',
                        map_location=torch.device(DEVICE))
model.load_state_dict(checkpoint['model'])

title = "# IEBins: Iterative Elastic Bins for Monocular Depth Estimation"
description = """Demo for **IEBins: Iterative Elastic Bins for Monocular Depth Estimation**.
Please refer to the [paper](https://arxiv.org/abs/2309.14137), [github](https://github.com/ShuweiShao/IEBins), or [poster](https://nips.cc/media/PosterPDFs/NeurIPS%202023/70695.png?t=1701662442.5228624) for more details."""

transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])


@spaces.GPU
@torch.no_grad()
def predict_depth(model, image):
    return model(image)


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        input_image = gr.Image(label="Input Image",
                               type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id='img-display-output', position=0.5,)
    raw_file = gr.File(
        label="16-bit raw depth (can be considered as disparity)")
    submit = gr.Button("Submit")

    def on_submit(image):
        original_image = image.copy()

        h, w = image.shape[:2]

        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
        # image = transform({'image': image})['image']
        # image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

        image = np.asarray(image, dtype=np.float32) / 255.0
        image = torch.from_numpy(image.transpose((2, 0, 1)))
        image = Normalize(mean=[0.485, 0.456, 0.406], std=[
            0.229, 0.224, 0.225])(image)
        image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

        depth = predict_depth(model, image)
        depth = F.interpolate(depth[None], (h, w),
                              mode='bilinear', align_corners=False)[0, 0]

        raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
        tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        raw_depth.save(tmp.name)

        depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
        depth = depth.cpu().numpy().astype(np.uint8)
        colored_depth = cv2.applyColorMap(
            depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]

        return [(original_image, colored_depth), tmp.name]

    submit.click(on_submit, inputs=[input_image], outputs=[
                 depth_image_slider, raw_file])

    example_files = os.listdir('examples')
    example_files.sort()
    example_files = [os.path.join('examples', filename)
                     for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[
                           depth_image_slider, raw_file], fn=on_submit, cache_examples=False)


if __name__ == '__main__':
    demo.queue().launch()