File size: 3,571 Bytes
b19928f
 
a394b1d
b19928f
 
d6c2352
 
2e549d0
bdf07c0
 
5a3dc03
e9d914e
d6c2352
b19928f
d3f5a26
 
b19928f
 
d3f5a26
b19928f
 
30d77b6
26691a8
 
 
 
 
 
 
 
bdf07c0
 
 
 
26691a8
d3f5a26
b19928f
bdf07c0
2e549d0
26691a8
bdf07c0
26691a8
2e549d0
bdf07c0
2e549d0
 
 
d3f5a26
2e549d0
 
 
 
 
 
 
 
 
 
 
 
 
 
4280f5a
 
 
 
 
 
2e549d0
30d77b6
4280f5a
 
 
2e549d0
 
 
4280f5a
30d77b6
2e549d0
 
 
 
 
bdf07c0
 
 
 
b19928f
596c382
 
 
 
 
b19928f
 
 
 
30d77b6
596c382
30d77b6
 
c0a3f1e
596c382
 
b19928f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os

# Run the script to get pretrained models
# (fetches the DepthPro checkpoint; assumed idempotent if already present — TODO confirm).
subprocess.run(["bash", "get_pretrained_models.sh"])

# Prefer the first CUDA GPU when available; fall back to CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform once at import time so every
# Gradio request reuses the same weights.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()

def resize_image(image_path, max_size=1536):
    """Shrink the image at *image_path* so its longest side is at most
    *max_size* pixels, preserving aspect ratio, and save it as a
    temporary PNG.

    Images already within the limit are saved unchanged — upscaling a
    small image would only blur it and inflate inference cost.

    Args:
        image_path: Path to the input image file.
        max_size: Maximum allowed length (pixels) of the longest side.

    Returns:
        Path of the temporary PNG file. The caller is responsible for
        deleting it (predict_depth does so in its ``finally`` block).
    """
    with Image.open(image_path) as img:
        # Clamp the ratio at 1.0: only downscale, never upscale.
        ratio = min(1.0, max_size / max(img.size))
        if ratio < 1.0:
            new_size = tuple(int(dim * ratio) for dim in img.size)
            img = img.resize(new_size, Image.LANCZOS)

        # delete=False so the file survives the context manager; the
        # caller removes it when done.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name

@spaces.GPU(duration=20)
def predict_depth(input_image):
    """Run DepthPro on *input_image* and render its inverse depth map.

    Args:
        input_image: Filesystem path to the uploaded image (Gradio
            ``type="filepath"`` input component).

    Returns:
        ``(output_path, message)`` — path of the rendered inverse-depth
        PNG and a focal-length summary string, or ``(None, error_message)``
        on failure. Errors are returned rather than raised so Gradio can
        show them in the output textbox.
    """
    temp_file = None
    fig = None
    try:
        # Cap the input at 1536 px on the longest side before inference.
        temp_file = resize_image(input_image)

        # Preprocess the image. load_rgb returns a tuple whose first
        # element is the image and whose last element is the focal-length
        # estimate in pixels — TODO confirm against depth_pro.load_rgb.
        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]
        image = transform(image)
        image = image.to(device)

        # Run inference.
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert depth to a numpy array if it's a torch tensor.
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Collapse singleton dims so depth is strictly 2D for imshow.
        if depth.ndim != 2:
            depth = depth.squeeze()

        # Inverse depth; silence divide-by-zero warnings — any resulting
        # inf is clamped by the clip below.
        with np.errstate(divide="ignore"):
            inverse_depth = 1.0 / depth

        # Clip inverse depth to the 0-10 range used for visualization.
        inverse_depth_clipped = np.clip(inverse_depth, 0, 10)

        # Keep a figure handle so cleanup is guaranteed even if a
        # plotting call raises (plt.figure alone leaked on exceptions).
        fig = plt.figure(figsize=(15.36, 15.36), dpi=100)  # 1536x1536 pixels
        plt.imshow(inverse_depth_clipped, cmap='viridis')
        plt.colorbar(label='Inverse Depth')
        plt.title('Predicted Inverse Depth Map')
        plt.axis('off')

        # Save the plot to a file.
        output_path = "inverse_depth_map.png"
        plt.savefig(output_path, dpi=100, bbox_inches='tight', pad_inches=0)

        return output_path, f"Focal length: {focallength_px:.2f} pixels"
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
    finally:
        # Always release the figure and the temporary resized image.
        if fig is not None:
            plt.close(fig)
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)

# Sample inputs shown beneath the interface.
example_images = ["examples/lemur.jpg"]

# Output components: the rendered inverse-depth map plus a textbox that
# carries either the focal-length estimate or an error message.
_output_components = [
    gr.Image(type="filepath", label="Inverse Depth Map", height=768, width=768),
    gr.Textbox(label="Focal Length or Error Message"),
]

# Wire predict_depth into a Gradio interface.
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=_output_components,
    title="DepthPro Demo",
    description="[DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload an image to predict its inverse depth map and focal length. Large images will be automatically resized to 1536x1536 pixels.",
    examples=example_images,
)

# Start the web app.
iface.launch()