File size: 4,506 Bytes
bd86ed9
 
 
 
 
 
 
 
3fb6608
bd86ed9
 
0f1bbf6
bd86ed9
 
3fbdaa2
8b0757c
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
c19a5d3
92224a7
 
 
 
 
 
 
 
c19a5d3
 
 
 
92224a7
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fb6608
 
 
 
3034d2d
f92af9d
3fb6608
 
7a8232f
dccb91c
 
bd86ed9
dccb91c
 
 
 
 
bd86ed9
dccb91c
f1ab9e0
2a2d2f8
dccb91c
f1ab9e0
bd86ed9
f1ab9e0
bd86ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import spaces
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Normalize
import tempfile
from gradio_imageslider import ImageSlider
import matplotlib.pyplot as plt

from iebins.networks.NewCRFDepth import NewCRFDepth
from iebins.util.transfrom import Resize, NormalizeImage, PrepareForNet
from iebins.utils import post_process_depth, flip_lr

# Custom CSS passed to gr.Blocks below. It caps the image panels at 80% of the
# viewport height. "#img-display-input" and "#img-display-output" match the
# elem_id values of the input image and the ImageSlider components;
# "#img-display-container" is not used as an elem_id anywhere in this file.
css = """
#img-display-container {
    max-height: 100vh;
    }
#img-display-input {
    max-height: 80vh;
    }
#img-display-output {
    max-height: 80vh;
    }
"""
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT_PATH = 'checkpoints/nyu_L.pth'

# Build the IEBins network (Swin "large07" variant, depth capped at 10).
# pretrained=None: all weights come from the checkpoint loaded below.
model = NewCRFDepth(version='large07', inv_depth=False,
                    max_depth=10, pretrained=None).to(DEVICE)

num_params = sum(p.numel() for p in model.parameters())
print("== Total number of parameters: {}".format(num_params))
num_params_update = sum(p.numel() for p in model.parameters()
                        if p.requires_grad)
print("== Total number of learning parameters: {}".format(num_params_update))

# Wrap before loading: presumably the checkpoint was saved from a
# DataParallel-wrapped model, so its state-dict keys carry the "module."
# prefix. NOTE(review): confirm against the checkpoint if this is changed.
model = torch.nn.DataParallel(model)
checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device(DEVICE))
model.load_state_dict(checkpoint['model'])
# Put the network in inference mode *after* loading. Leaving it in training
# mode would keep any Dropout active and make BatchNorm-type layers use batch
# statistics, giving noisy/non-deterministic predictions. eval() on the
# DataParallel wrapper propagates to the wrapped module.
model.eval()
print("== Loaded checkpoint '{}'".format(CHECKPOINT_PATH))

# Markdown shown at the top of the demo page (rendered via gr.Markdown below).
title = "# IEBins: Iterative Elastic Bins for Monocular Depth Estimation"
description = """Demo for **IEBins: Iterative Elastic Bins for Monocular Depth Estimation**.
Please refer to the [paper](https://arxiv.org/abs/2309.14137), [github](https://github.com/ShuweiShao/IEBins), or [poster](https://nips.cc/media/PosterPDFs/NeurIPS%202023/70695.png?t=1701662442.5228624) for more details."""

# Resize/normalise/layout pipeline: resize to at least 518x518 keeping the
# aspect ratio (sides rounded to multiples of 14, bicubic), normalise with
# ImageNet mean/std, then convert to the network layout via PrepareForNet.
# NOTE(review): `transform` is only referenced from commented-out code in
# this file; confirm whether it is still needed.
transform = Compose(
    [
        Resize(
            height=518,
            width=518,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method='lower_bound',
            resize_target=False,
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        PrepareForNet(),
    ]
)


@spaces.GPU
@torch.no_grad()
def predict_depth(model, image):
    """Call ``model`` on ``image`` with autograd disabled and return its output.

    Decorated with ``spaces.GPU`` so that, on Hugging Face Spaces, the call
    runs inside a GPU-allocated context.
    NOTE(review): this function is not referenced elsewhere in this file.
    """
    outputs = model(image)
    return outputs


# ImageNet statistics used to normalise the RGB input before the forward pass.
_NORMALIZE = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Multiplier applied to the predicted depth before storing it as uint16 in the
# downloadable raw PNG. Assumes the model predicts metres (NYU checkpoint,
# max_depth=10), so 1000 stores millimetres.
# NOTE(review): confirm the output unit.
RAW_DEPTH_SCALE = 1000.0


def _to_input_tensor(image):
    """Convert a uint8 image from Gradio into a normalised 1x3xHxW float tensor.

    Grayscale inputs (HxW or HxWx1) are replicated to three channels and an
    alpha channel is dropped, so the network always receives RGB.
    """
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        image = np.repeat(image[..., None], 3, axis=-1)
    elif image.shape[2] == 4:
        image = image[..., :3]
    array = np.asarray(image, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(np.ascontiguousarray(array.transpose((2, 0, 1))))
    return _NORMALIZE(tensor).unsqueeze(0)


@spaces.GPU
@torch.no_grad()
def _estimate_depth(batch):
    """Predict depth for ``batch`` with horizontal-flip test-time augmentation.

    The tensor is moved to DEVICE here, so every CUDA operation happens inside
    this ``spaces.GPU``-decorated call. Returns a squeezed numpy array on CPU.
    NOTE(review): the input resolution is passed straight to the network;
    confirm NewCRFDepth accepts sizes that are not multiples of its stride.
    """
    batch = batch.to(DEVICE)
    pred_depths, _, _ = model(batch)
    pred_depths_flipped, _, _ = model(flip_lr(batch))
    # Fuse the last entry of each returned prediction list.
    depth = post_process_depth(pred_depths[-1], pred_depths_flipped[-1])
    return depth.cpu().numpy().squeeze()


def _colorize_depth(depth, cmap_name='jet'):
    """Map a depth array to an RGB uint8 image using a matplotlib colormap.

    Values are min-max scaled per image (the same default scaling
    ``plt.imsave`` applies). A constant map is rendered with the colormap's
    lowest colour instead of dividing by zero.
    """
    depth = np.asarray(depth, dtype=np.float32)
    lo, hi = float(depth.min()), float(depth.max())
    span = hi - lo
    normalized = (depth - lo) / span if span > 0 else np.zeros_like(depth)
    rgba = plt.get_cmap(cmap_name)(normalized)
    return (rgba[..., :3] * 255).astype(np.uint8)


def _save_raw_depth_png(depth):
    """Write ``depth`` as a 16-bit PNG in a fresh temporary file; return its path.

    Values are scaled by RAW_DEPTH_SCALE, rounded and clipped to the uint16
    range. The file is created with delete=False so Gradio can serve it after
    this function returns; each request gets its own file.

    Raises:
        RuntimeError: if OpenCV fails to write the PNG.
    """
    scaled = np.clip(np.rint(np.asarray(depth) * RAW_DEPTH_SCALE),
                     0, np.iinfo(np.uint16).max)
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        path = tmp.name
    if not cv2.imwrite(path, scaled.astype(np.uint16)):
        raise RuntimeError("Failed to write raw depth PNG to {}".format(path))
    return path


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        input_image = gr.Image(label="Input Image",
                               type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id='img-display-output', position=0.5,)
    raw_file = gr.File(
        label="16-bit raw depth PNG (millimetres)")
    submit = gr.Button("Submit")

    def on_submit(image):
        """Run depth estimation on the uploaded image.

        Returns ``[(input, colour-mapped depth), path to 16-bit raw depth PNG]``,
        matching the ``[depth_image_slider, raw_file]`` outputs.

        Raises:
            gr.Error: if no image has been provided.
        """
        if image is None:
            raise gr.Error("Please upload an image first.")
        original_image = image.copy()
        depth = _estimate_depth(_to_input_tensor(image))
        colored_depth = _colorize_depth(depth)
        raw_path = _save_raw_depth_png(depth)
        return [(original_image, colored_depth), raw_path]

    submit.click(on_submit, inputs=[input_image],
                 outputs=[depth_image_slider, raw_file])

    example_files = [os.path.join('examples', filename)
                     for filename in sorted(os.listdir('examples'))]
    examples = gr.Examples(examples=example_files, inputs=[input_image],
                           outputs=[depth_image_slider, raw_file],
                           fn=on_submit, cache_examples=False)


if __name__ == '__main__':
    # Enable the request queue, then start serving the demo.
    queued_demo = demo.queue()
    queued_demo.launch()