umuthopeyildirim commited on
Commit
db961ba
1 Parent(s): 92224a7

Refactor depth prediction and post-processing in app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -93,11 +93,15 @@ with gr.Blocks(css=css) as demo:
93
  # image = torch.from_numpy(image).unsqueeze(0)
94
  image = torch.autograd.Variable(image.unsqueeze(0))
95
 
96
- depth = predict_depth(model, image)
97
- depth = F.interpolate(depth[None], (h, w),
98
- mode='bilinear', align_corners=False)[0, 0]
 
 
99
 
100
- raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
 
 
101
  tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
102
  raw_depth.save(tmp.name)
103
 
 
93
  # image = torch.from_numpy(image).unsqueeze(0)
94
  image = torch.autograd.Variable(image.unsqueeze(0))
95
 
96
+ pred_depths_r_list, _, _ = predict_depth(model, image)
97
+ image_flipped = flip_lr(image)
98
+ pred_depths_r_list_flipped, _, _ = model(image_flipped)
99
+ pred_depth = post_process_depth(
100
+ pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
101
 
102
+ pred_depth = pred_depth.cpu().numpy().squeeze()
103
+
104
+ raw_depth = Image.fromarray(pred_depth.cpu().numpy().astype('uint16'))
105
  tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
106
  raw_depth.save(tmp.name)
107