akhaliq (HF staff) committed
Commit 2e549d0
1 Parent(s): e9d914e

Update app.py

Files changed (1)
  1. app.py +42 -28
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
 import matplotlib.pyplot as plt
 import subprocess
 import spaces
+import torch
 
 # Run the script to get pretrained models
 subprocess.run(["bash", "get_pretrained_models.sh"])
@@ -15,39 +16,52 @@ model.eval()
 
 @spaces.GPU(duration=120)
 def predict_depth(input_image):
-    # Preprocess the image
-    result = depth_pro.load_rgb(input_image) # Removed .name
-    image = result[0]
-    f_px = result[-1] # Assuming f_px is the last item in the returned tuple
-    image = transform(image)
-
-    # Run inference
-    prediction = model.infer(image, f_px=f_px)
-    depth = prediction["depth"] # Depth in [m]
-    focallength_px = prediction["focallength_px"] # Focal length in pixels
-
-    # Normalize depth for visualization
-    depth_normalized = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
-
-    # Create a color map
-    plt.figure(figsize=(10, 10))
-    plt.imshow(depth_normalized, cmap='viridis')
-    plt.colorbar(label='Depth')
-    plt.title('Predicted Depth Map')
-    plt.axis('off')
-
-    # Save the plot to a file
-    output_path = "depth_map.png"
-    plt.savefig(output_path)
-    plt.close()
-
-    return output_path, f"Focal length: {focallength_px:.2f} pixels"
+    try:
+        # Preprocess the image
+        result = depth_pro.load_rgb(input_image)
+        image = result[0]
+        f_px = result[-1] # Assuming f_px is the last item in the returned tuple
+        image = transform(image)
+
+        # Run inference
+        prediction = model.infer(image, f_px=f_px)
+        depth = prediction["depth"] # Depth in [m]
+        focallength_px = prediction["focallength_px"] # Focal length in pixels
+
+        # Convert depth to numpy array if it's a torch tensor
+        if isinstance(depth, torch.Tensor):
+            depth = depth.cpu().numpy()
+
+        # Ensure depth is a 2D numpy array
+        if depth.ndim != 2:
+            depth = depth.squeeze()
+
+        # Normalize depth for visualization
+        depth_min = np.min(depth)
+        depth_max = np.max(depth)
+        depth_normalized = (depth - depth_min) / (depth_max - depth_min)
+
+        # Create a color map
+        plt.figure(figsize=(10, 10))
+        plt.imshow(depth_normalized, cmap='viridis')
+        plt.colorbar(label='Depth')
+        plt.title('Predicted Depth Map')
+        plt.axis('off')
+
+        # Save the plot to a file
+        output_path = "depth_map.png"
+        plt.savefig(output_path)
+        plt.close()
+
+        return output_path, f"Focal length: {focallength_px:.2f} pixels"
+    except Exception as e:
+        return None, f"An error occurred: {str(e)}"
 
 # Create Gradio interface
 iface = gr.Interface(
     fn=predict_depth,
     inputs=gr.Image(type="filepath"),
-    outputs=[gr.Image(type="filepath", label="Depth Map"), gr.Textbox(label="Focal Length")],
+    outputs=[gr.Image(type="filepath", label="Depth Map"), gr.Textbox(label="Focal Length or Error Message")],
     title="Depth Prediction Demo",
     description="Upload an image to predict its depth map and focal length."
 )
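
The substantive change is the new depth post-processing: the predicted depth may come back as a torch tensor with extra singleton dimensions, so it is moved to the CPU, converted to NumPy, squeezed to 2D, and min-max normalized before plotting. A standalone sketch of that logic, using a dummy (1, H, W) tensor purely as an assumed stand-in for prediction["depth"]:

import numpy as np
import torch

# Dummy stand-in for prediction["depth"]; the (1, 4, 4) shape is assumed here
# only to exercise the squeeze() branch added in this commit.
depth = torch.rand(1, 4, 4) * 10.0

# Same conversion and normalization steps as the updated predict_depth()
if isinstance(depth, torch.Tensor):
    depth = depth.cpu().numpy()
if depth.ndim != 2:
    depth = depth.squeeze()

depth_min = np.min(depth)
depth_max = np.max(depth)
depth_normalized = (depth - depth_min) / (depth_max - depth_min)

print(depth_normalized.shape)  # (4, 4)
print(depth_normalized.min(), depth_normalized.max())  # 0.0 1.0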
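
The diff also leans on names defined elsewhere in app.py: depth_pro, model, and transform, with only model.eval() visible in the second hunk header. Below is a minimal sketch of how those objects are typically produced; the identifiers mirror the public Depth Pro API, but the example path and the surrounding layout of app.py are assumptions:

import depth_pro
import torch

# Assumed setup implied by the "model.eval()" hunk header: the standard
# depth_pro API builds the network and its preprocessing transform together.
model, transform = depth_pro.create_model_and_transforms()
model.eval()

# The same round trip that predict_depth() performs in the diff above,
# shown here with a hypothetical local image path.
image, _, f_px = depth_pro.load_rgb("example.jpg")
image = transform(image)
with torch.no_grad():
    prediction = model.infer(image, f_px=f_px)
print(prediction["depth"].shape, prediction["focallength_px"])

The Space would also need to start the UI, typically with iface.launch() at the bottom of app.py, which falls outside the range shown in this diff.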