File size: 4,426 Bytes
bd86ed9 3fb6608 bd86ed9 3fbdaa2 bd86ed9 c19a5d3 92224a7 c19a5d3 92224a7 bd86ed9 3fb6608 3034d2d f92af9d 3fb6608 7a8232f 20bb053 bd86ed9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import spaces
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Normalize
import tempfile
from gradio_imageslider import ImageSlider
from iebins.networks.NewCRFDepth import NewCRFDepth
from iebins.util.transfrom import Resize, NormalizeImage, PrepareForNet
# Stylesheet handed to gr.Blocks: caps the heights of the page container, the
# input image, and the depth slider so they fit within the viewport.
css = (
    "\n"
    "#img-display-container {\n"
    "max-height: 100vh;\n"
    "}\n"
    "#img-display-input {\n"
    "max-height: 80vh;\n"
    "}\n"
    "#img-display-output {\n"
    "max-height: 80vh;\n"
    "}\n"
)
# Run on the GPU when torch can see one, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
_CHECKPOINT_PATH = 'checkpoints/nyu_L.pth'

# IEBins network. max_depth=10 bounds the predicted depth; given the NYU
# checkpoint this is presumably metres (indoor range) — TODO confirm.
model = NewCRFDepth(version='large07', inv_depth=False,
                    max_depth=10, pretrained=None).to(DEVICE)

# Informational parameter counts. numel() always yields an int, whereas
# np.prod() over the empty shape of a 0-d parameter yields the float 1.0.
num_params = sum(p.numel() for p in model.parameters())
print("== Total number of parameters: {}".format(num_params))
num_params_update = sum(p.numel()
                        for p in model.parameters() if p.requires_grad)
print("== Total number of learning parameters: {}".format(num_params_update))

# The state dict is loaded into the DataParallel wrapper, so its keys
# presumably carry the 'module.' prefix — TODO confirm against the checkpoint.
model = torch.nn.DataParallel(model)
checkpoint = torch.load(_CHECKPOINT_PATH,
                        map_location=torch.device(DEVICE))
model.load_state_dict(checkpoint['model'])
print("== Loaded checkpoint '{}'".format(_CHECKPOINT_PATH))

# Inference-only demo: keep dropout/normalization layers in evaluation mode.
# (Calling model.train() here would make every prediction use training-mode
# layer behaviour.)
model.eval()
# Markdown rendered at the top of the demo page; both strings share the
# paper's full title.
_PAPER_TITLE = "IEBins: Iterative Elastic Bins for Monocular Depth Estimation"
title = "# " + _PAPER_TITLE
description = (
    "Demo for **" + _PAPER_TITLE + "**.\n"
    "Please refer to the [paper](https://arxiv.org/abs/2309.14137), "
    "[github](https://github.com/ShuweiShao/IEBins), "
    "or [poster](https://nips.cc/media/PosterPDFs/NeurIPS%202023/70695.png?t=1701662442.5228624) "
    "for more details."
)
# Preprocessing pipeline: aspect-preserving resize toward 518x518 (lower-bound,
# bicubic, sides rounded to multiples of 14), ImageNet mean/std normalization,
# then PrepareForNet.
# NOTE(review): the visible on_submit does not use this pipeline (its calls are
# commented out). The 518 / multiple-of-14 settings look inherited from a
# patch-14 ViT model rather than IEBins — confirm before re-enabling.
_transform_steps = [
    Resize(width=518, height=518,
           resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method='lower_bound',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406],
                   std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
]
transform = Compose(_transform_steps)
@spaces.GPU
@torch.no_grad()
def predict_depth(model, image):
    """Forward ``image`` through ``model`` with autograd disabled.

    ``spaces.GPU`` presumably attaches a Hugging Face ZeroGPU device for the
    duration of the call — see the ``spaces`` package docs.

    Args:
        model: The (DataParallel-wrapped) depth network.
        image: Preprocessed, batched input tensor.

    Returns:
        Whatever ``model(image)`` returns, unchanged.
    """
    prediction = model(image)
    return prediction
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        input_image = gr.Image(label="Input Image",
                               type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id='img-display-output', position=0.5,)
    raw_file = gr.File(
        label="16-bit raw depth (millimeters)")
    submit = gr.Button("Submit")

    def on_submit(image):
        """Estimate depth for one uploaded image.

        Args:
            image: Array from ``gr.Image(type='numpy')``; presumably an RGB
                uint8 array of shape (H, W, 3) — TODO confirm for unusual
                uploads. ``None`` when Submit is pressed without an image.

        Returns:
            ``[(original_image, colored_depth), raw_path]``: the pair shown in
            the slider (input vs. INFERNO-colored depth at input resolution,
            RGB channel order) and the path of a 16-bit PNG holding depth in
            millimetres.

        Raises:
            gr.Error: if no image was provided.
        """
        if image is None:
            raise gr.Error("Please upload an image first.")
        original_image = image.copy()
        h, w = image.shape[:2]

        # HWC uint8 [0, 255] -> CHW float32 [0, 1], then ImageNet normalization.
        image = np.asarray(image, dtype=np.float32) / 255.0
        image = torch.from_numpy(image.transpose((2, 0, 1)))
        image = Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])(image)
        # Add the batch dimension. No explicit device move: DataParallel
        # scatters inputs to its device (or calls the module directly on CPU).
        image = image.unsqueeze(0)

        # NOTE(review): assumes predict_depth returns a (1, H', W') tensor so
        # that depth[None] is 4-D as bilinear F.interpolate requires — confirm
        # against NewCRFDepth.forward's actual return value.
        depth = predict_depth(model, image)
        depth = F.interpolate(depth[None], (h, w),
                              mode='bilinear', align_corners=False)[0, 0]
        depth = depth.cpu().numpy()

        # Raw output: the model is built with max_depth=10 (presumably metres),
        # so a bare uint16 cast would keep only ~11 integer levels. Store
        # millimetres instead, clipped to the uint16 range.
        raw_depth = Image.fromarray(
            np.clip(depth * 1000.0, 0, np.iinfo(np.uint16).max).astype(np.uint16))
        # Close the temp-file handle before writing by name (an open
        # NamedTemporaryFile cannot be reopened on Windows and would otherwise
        # leak a descriptor); delete=False keeps the file for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            raw_path = tmp.name
        raw_depth.save(raw_path)

        # Visualization: min-max scale to [0, 255]. A constant map would divide
        # by zero (NaN before the uint8 cast), so render it as all zeros.
        d_min, d_max = float(depth.min()), float(depth.max())
        if d_max > d_min:
            depth = (depth - d_min) / (d_max - d_min) * 255.0
        else:
            depth = np.zeros_like(depth)
        # applyColorMap yields BGR; reverse the channel axis to get RGB.
        colored_depth = cv2.applyColorMap(
            depth.astype(np.uint8), cv2.COLORMAP_INFERNO)[:, :, ::-1]
        return [(original_image, colored_depth), raw_path]

    submit.click(on_submit, inputs=[input_image], outputs=[
        depth_image_slider, raw_file])

    # Example gallery: every entry under ./examples, in sorted order.
    example_files = sorted(
        os.path.join('examples', filename) for filename in os.listdir('examples'))
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[
        depth_image_slider, raw_file], fn=on_submit, cache_examples=False)
if __name__ == '__main__':
    # Turn on Gradio's request queue, then start serving the demo.
    queued_demo = demo.queue()
    queued_demo.launch()
|