import gradio as gr
import cv2
import numpy as np
import os
from PIL import Image
import spaces
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Normalize
import tempfile
from gradio_imageslider import ImageSlider
import matplotlib.pyplot as plt
from iebins.networks.NewCRFDepth import NewCRFDepth
from iebins.util.transfrom import Resize, NormalizeImage, PrepareForNet
from iebins.utils import post_process_depth, flip_lr
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Build the model; eval mode is set once the checkpoint is loaded below.
model = NewCRFDepth(version='large07', inv_depth=False,
                    max_depth=10, pretrained=None).to(DEVICE)
num_params = sum(p.numel() for p in model.parameters())
print("== Total number of parameters: {}".format(num_params))
num_params_update = sum(p.numel()
                        for p in model.parameters() if p.requires_grad)
print("== Total number of trainable parameters: {}".format(num_params_update))
# Wrap in DataParallel before loading: the checkpoint's state-dict keys are
# expected to carry the 'module.' prefix.
model = torch.nn.DataParallel(model)
checkpoint = torch.load('checkpoints/nyu_L.pth',
                        map_location=torch.device(DEVICE))
model.load_state_dict(checkpoint['model'])
model.eval()  # inference only
print("== Loaded checkpoint '{}'".format('checkpoints/nyu_L.pth'))
title = "# IEBins: Iterative Elastic Bins for Monocular Depth Estimation"
description = """Demo for **IEBins: Iterative Elastic Bins for Monocular Depth Estimation**.
Please refer to the [paper](https://arxiv.org/abs/2309.14137), [github](https://github.com/ShuweiShao/IEBins), or [poster](https://nips.cc/media/PosterPDFs/NeurIPS%202023/70695.png?t=1701662442.5228624) for more details."""
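# NOTE: the transform below (518 px, multiples of 14, ViT-style settings) is
# defined but never used; on_submit() does its own normalization. It is kept
# for reference alongside the alternative preprocessing path shown after it.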
transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method='lower_bound',
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])
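# If the transform above were wired in, the preprocessing would look roughly
# like the path that was previously left commented out in on_submit():
#   image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
#   image = transform({'image': image})['image']
#   image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)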
@spaces.GPU
@torch.no_grad()
def predict_depth(model, image):
    return model(image)
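# @spaces.GPU asks ZeroGPU Spaces to attach a GPU for the duration of the
# call, and @torch.no_grad() skips autograd bookkeeping; on_submit routes both
# forward passes through this wrapper so the decorators actually take effect.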
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        input_image = gr.Image(label="Input Image",
                               type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View", elem_id='img-display-output',
            position=0.5)
    raw_file = gr.File(
        label="16-bit raw depth (can be considered as disparity)")
    submit = gr.Button("Submit")
    def on_submit(image):
        original_image = image.copy()
        h, w = image.shape[:2]

        # Normalize with ImageNet statistics and lay the image out as (C, H, W).
        image = np.asarray(image, dtype=np.float32) / 255.0
        image = torch.from_numpy(image.transpose((2, 0, 1)))
        image = Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])(image)
        image = image.unsqueeze(0).to(DEVICE)  # add batch dim, move to model device
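        # Sketch under an assumption: the Swin backbone behind NewCRFDepth
        # typically needs spatial dims divisible by 32, and arbitrary uploads
        # can otherwise fail the forward pass. Resize to the nearest multiple
        # of 32 here and restore the original resolution after inference.
        new_h, new_w = max(32, round(h / 32) * 32), max(32, round(w / 32) * 32)
        image = F.interpolate(image, size=(new_h, new_w),
                              mode='bilinear', align_corners=False)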
        with torch.no_grad():
            print("== Processing image")
            pred_depths_r_list, _, _ = predict_depth(model, image)
            image_flipped = flip_lr(image)
            pred_depths_r_list_flipped, _, _ = predict_depth(
                model, image_flipped)
            # Average the prediction with its horizontally flipped counterpart,
            # the test-time augmentation used by the IEBins evaluation code.
            pred_depth = post_process_depth(
                pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
            print("== Finished processing image")
        pred_depth = pred_depth.squeeze().cpu().numpy()

        # Save the raw depth as a 16-bit PNG; PIL cannot write float32 PNGs.
        # Assumption: scaling metres by 1000 stores millimetre precision, the
        # usual convention for NYU depth maps.
        raw_depth = Image.fromarray((pred_depth * 1000.0).astype(np.uint16))
        tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        raw_depth.save(tmp.name)

        # Colorize the depth map for the slider view instead of echoing the
        # input image twice.
        depth_norm = (pred_depth - pred_depth.min()) / \
            (pred_depth.max() - pred_depth.min() + 1e-8)
        colored_depth = (plt.cm.jet(depth_norm)[:, :, :3] * 255).astype(np.uint8)

        return [(original_image, colored_depth), tmp.name]
    submit.click(on_submit, inputs=[input_image],
                 outputs=[depth_image_slider, raw_file])

    example_files = sorted(os.listdir('examples'))
    example_files = [os.path.join('examples', filename)
                     for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image],
                           outputs=[depth_image_slider, raw_file],
                           fn=on_submit, cache_examples=False)
if __name__ == '__main__':
    demo.queue().launch()