ohjho commited on
Commit
a327756
1 Parent(s): 5a54fe2

updated the app for DPT model

Browse files
Files changed (2) hide show
  1. DPT.py +5 -4
  2. app.py +29 -12
DPT.py CHANGED
@@ -14,6 +14,7 @@ def load_model(model_type = 'DPT_Large'):
14
  midas = torch.hub.load("intel-isl/MiDaS", model_type)
15
 
16
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
17
  midas.to(device)
18
  midas.eval()
19
 
@@ -27,13 +28,13 @@ def load_model(model_type = 'DPT_Large'):
27
  'midas': midas, 'device': device, 'transform': transform
28
  }
29
 
30
- def inference(img_array_rgb, model_def):
31
  '''run DPT model and returns a PIL image'''
32
  # img = cv2.imread(img.name)
33
  # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
34
- midas = model_def['midas']
35
- transform = model_def['transform']
36
- device = model_def['device']
37
  input_batch = transform(img_array_rgb).to(device)
38
 
39
  with torch.no_grad():
 
14
  midas = torch.hub.load("intel-isl/MiDaS", model_type)
15
 
16
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
17
+ print(f'---DPT will use device: {device}')
18
  midas.to(device)
19
  midas.eval()
20
 
 
28
  'midas': midas, 'device': device, 'transform': transform
29
  }
30
 
31
+ def inference(img_array_rgb, model_obj):
32
  '''run DPT model and returns a PIL image'''
33
  # img = cv2.imread(img.name)
34
  # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
35
+ midas = model_obj['midas']
36
+ transform = model_obj['transform']
37
+ device = model_obj['device']
38
  input_batch = transform(img_array_rgb).to(device)
39
 
40
  with torch.no_grad():
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import streamlit as st
2
 
3
- import os, sys, io
4
  import urllib.request as urllib
5
  import numpy as np
6
  from PIL import Image
7
 
8
- import DPT
9
 
10
  ### Some Utils Functions ###
11
  def get_image(st_asset = st.sidebar, as_np_arr = False, extension_list = ['jpg', 'jpeg', 'png']):
@@ -70,12 +70,31 @@ def im_draw_bbox(pil_im, x0, y0, x1, y1, color = 'black', width = 3, caption = N
70
  return Image.fromarray(im_array)
71
 
72
  ### Streamlit App ###
 
 
 
 
 
 
 
73
 
74
- def mod_DPT(pil_im, model_def):
75
- depth_im = DPT.inference(img_array_rgb = np.array(pil_im), model_def = model_def)
 
 
 
 
 
 
 
 
 
 
 
 
76
  return depth_im
77
 
78
- def Main(model_dict):
79
  st.set_page_config(layout = 'wide')
80
  l_col, r_col = st.columns(2)
81
  show_miro_logo(st_asset = l_col)
@@ -84,20 +103,18 @@ def Main(model_dict):
84
  Comparsion of two models: [BTS (CNN)](https://github.com/ErenBalatkan/Bts-PyTorch)
85
  and [DPT (Transformer)](https://huggingface.co/Intel/dpt-large)
86
  ''')
87
-
88
  im = get_image(st_asset = r_col.expander('Input Image', expanded = True), extension_list = ['jpg','jpeg'])
89
- model_name = r_col.selectbox('Pick Model', options = ['DPT','BTS'])
90
 
91
  if im:
92
- model_def = DPT.load_model()
93
- d_im = mod_DPT(pil_im = im, model_def=model_def)
94
 
95
  l_col, r_col = st.columns(2)
96
  l_col.image(im, caption = 'Input Image')
97
- r_col.image(saliency_im, caption = 'Depth Map')
98
  else:
99
  st.warning(f'please provide an image :point_up:')
100
 
101
  if __name__ == '__main__':
102
- model_dict = load_model()
103
- Main(model_dict = model_dict)
 
1
  import streamlit as st
2
 
3
+ import os, sys, io, time
4
  import urllib.request as urllib
5
  import numpy as np
6
  from PIL import Image
7
 
8
+ import DPT, BTS_infer
9
 
10
  ### Some Utils Functions ###
11
  def get_image(st_asset = st.sidebar, as_np_arr = False, extension_list = ['jpg', 'jpeg', 'png']):
 
70
  return Image.fromarray(im_array)
71
 
72
  ### Streamlit App ###
73
+ @st.cache(allow_output_mutation = True)
74
+ def get_model_zoo():
75
+ model_zoo = {
76
+ 'DPT': {'infer_func': DPT.inference,'model': DPT.load_model()},
77
+ # 'BTS': {'infer_func': BTS_infer.inference,'model': BTS_infer.get_model()}
78
+ }
79
+ return model_zoo
80
 
81
+ @st.cache(suppress_st_warning=True)
82
+ def mono_depth(pil_im, model_name):
83
+ s_time = time.time()
84
+ model_zoo = get_model_zoo()
85
+ infer_func = model_zoo[model_name]['infer_func']
86
+ model_obj = model_zoo[model_name]['model']
87
+ depth_im = infer_func(img_array_rgb = np.array(pil_im),
88
+ model_obj = model_obj)
89
+ st.info(f'''
90
+ model name: {model_name}\n
91
+ inference time: `{round(time.time()-s_time,2)}` seconds\n
92
+ depth image shape: {np.array(depth_im).shape}\n
93
+ depth image type: {type(depth_im)}
94
+ ''')
95
  return depth_im
96
 
97
+ def Main():
98
  st.set_page_config(layout = 'wide')
99
  l_col, r_col = st.columns(2)
100
  show_miro_logo(st_asset = l_col)
 
103
  Comparsion of two models: [BTS (CNN)](https://github.com/ErenBalatkan/Bts-PyTorch)
104
  and [DPT (Transformer)](https://huggingface.co/Intel/dpt-large)
105
  ''')
106
+ model_zoo = get_model_zoo()
107
  im = get_image(st_asset = r_col.expander('Input Image', expanded = True), extension_list = ['jpg','jpeg'])
108
+ model_name = l_col.selectbox('Pick Model', options = list(model_zoo.keys()))
109
 
110
  if im:
111
+ d_im = mono_depth(pil_im = im, model_name=model_name)
 
112
 
113
  l_col, r_col = st.columns(2)
114
  l_col.image(im, caption = 'Input Image')
115
+ r_col.image(d_im, caption = 'Depth Map')
116
  else:
117
  st.warning(f'please provide an image :point_up:')
118
 
119
  if __name__ == '__main__':
120
+ Main()