ohjho committed
Commit dfcd969
1 Parent(s): 22d0af0

testing DPT app

Files changed (4)
  1. .gitignore +129 -0
  2. DPT.py +62 -0
  3. app.py +103 -0
  4. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,129 @@
+ # Hugging Face Space doesn't like binary files
+ *.jpg
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ data/
+ results/
+ weights/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
DPT.py ADDED
@@ -0,0 +1,62 @@
+ import cv2, torch
+ import urllib.request
+ import numpy as np
+ from PIL import Image
+
+ MODEL_DICT = {
+     "DPT_Large": "MiDaS v3 - Large (highest accuracy, slowest inference speed)",
+     "DPT_Hybrid": "MiDaS v3 - Hybrid (medium accuracy, medium inference speed)",
+     "MiDaS_small": "MiDaS v2.1 - Small (lowest accuracy, highest inference speed)"
+ }
+
+ def load_model(model_type = 'DPT_Large'):
+     assert model_type in MODEL_DICT.keys(), f'{model_type} is not a valid model_type: {MODEL_DICT.keys()}'
+     midas = torch.hub.load("intel-isl/MiDaS", model_type)
+
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     midas.to(device)
+     midas.eval()
+
+     midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+
+     if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
+         transform = midas_transforms.dpt_transform
+     else:
+         transform = midas_transforms.small_transform
+     return {
+         'midas': midas, 'device': device, 'transform': transform
+     }
+
+ def inference(img_array_rgb, model_def):
+     '''run the DPT model on an RGB image array and return the depth map as a PIL image'''
+     # img = cv2.imread(img.name)
+     # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     midas = model_def['midas']
+     transform = model_def['transform']
+     device = model_def['device']
+     input_batch = transform(img_array_rgb).to(device)
+
+     with torch.no_grad():
+         prediction = midas(input_batch)
+
+         prediction = torch.nn.functional.interpolate(
+             prediction.unsqueeze(1),
+             size=img_array_rgb.shape[:2],
+             mode="bicubic",
+             align_corners=False,
+         ).squeeze()
+
+     output = prediction.cpu().numpy()
+     formatted = (output * 255 / np.max(output)).astype('uint8')
+     img = Image.fromarray(formatted)
+     return img
+
+ # inputs = gr.inputs.Image(type='file', label="Original Image")
+ # outputs = gr.outputs.Image(type="pil", label="Output Image")
+
+ # title = "DPT-Large"
+ # description = "Gradio demo for DPT-Large: Vision Transformers for Dense Prediction. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+ # article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2103.13413' target='_blank'>Vision Transformers for Dense Prediction</a> | <a href='https://github.com/intel-isl/MiDaS' target='_blank'>Github Repo</a></p>"
+ #
+ # examples=[['dog.jpg']]
+ # gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, analytics_enabled=False, examples=examples, enable_queue=True).launch(debug=True)
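For reference, a minimal sketch of how `DPT.load_model` and `DPT.inference` compose outside the app; `dog.jpg` is borrowed from the commented-out Gradio example above, and any RGB image works:

```python
# Minimal usage sketch for DPT.py ('dog.jpg' is just the example file
# name from the Gradio demo; substitute any RGB image).
import numpy as np
from PIL import Image

import DPT

model_def = DPT.load_model(model_type='MiDaS_small')  # weights download via torch.hub on first call
img = np.array(Image.open('dog.jpg').convert('RGB'))  # HxWx3 uint8 RGB array, as inference() expects
depth_im = DPT.inference(img_array_rgb=img, model_def=model_def)
depth_im.save('depth.png')  # depth map scaled to 0-255 grayscale by inference()
```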
app.py ADDED
@@ -0,0 +1,103 @@
+ import streamlit as st
+
+ import os, sys, io, warnings
+ import urllib.request as urllib
+ import numpy as np
+ from PIL import Image, ImageColor
+
+ import DPT
+
+ ### Some Utils Functions ###
+ def get_image(st_asset = st.sidebar, as_np_arr = False, extension_list = ['jpg', 'jpeg', 'png']):
+     image_url, image_fh = None, None
+     if st_asset.checkbox('use image URL?'):
+         image_url = st_asset.text_input("Enter Image URL")
+     else:
+         image_fh = st_asset.file_uploader(label = "Upload your image", type = extension_list)
+
+     im = None
+     if image_url:
+         response = urllib.urlopen(image_url)
+         im = Image.open(io.BytesIO(bytearray(response.read())))
+     elif image_fh:
+         im = Image.open(image_fh)
+
+     if im and as_np_arr:
+         im = np.array(im)
+     return im
+
+ def show_miro_logo(use_column_width = False, width = 100, st_asset = st.sidebar):
+     logo_url = 'https://miro.medium.com/max/1400/0*qLL-32srlq6Y_iTm.png'
+     st_asset.image(logo_url, use_column_width = use_column_width, channels = 'BGR', output_format = 'PNG', width = width)
+
+ def im_draw_bbox(pil_im, x0, y0, x1, y1, color = 'black', width = 3, caption = None,
+                  bbv_label_only = False):
+     '''
+     draw a bounding box on a copy of the input image pil_im and return it as a new PIL image
+     Args:
+         color: color name as read by Pillow.ImageColor
+         bbv_label_only: draw only a flag-style label using bbox_visualizer
+     '''
+     import bbox_visualizer as bbv
+     if any([type(i) == float for i in [x0, y0, x1, y1]]):
+         warnings.warn('im_draw_bbox: at least one of x0, y0, x1, y1 is of the type float and is converted to int.')
+         x0 = int(x0)
+         y0 = int(y0)
+         x1 = int(x1)
+         y1 = int(y1)
+
+     if bbv_label_only:
+         if caption:
+             im_array = bbv.draw_flag_with_label(np.array(pil_im),
+                 label = caption,
+                 bbox = [x0, y0, x1, y1],
+                 line_color = ImageColor.getrgb(color),
+                 text_bg_color = ImageColor.getrgb(color)
+                 )
+         else:
+             raise ValueError('im_draw_bbox: bbv_label_only is True but caption is None')
+     else:
+         im_array = bbv.draw_rectangle(np.array(pil_im),
+             bbox = [x0, y0, x1, y1],
+             bbox_color = ImageColor.getrgb(color),
+             thickness = width
+             )
+         im_array = bbv.add_label(
+             im_array, label = caption,
+             bbox = [x0, y0, x1, y1],
+             text_bg_color = ImageColor.getrgb(color)
+             ) if caption else im_array
+     return Image.fromarray(im_array)
+
+ ### Streamlit App ###
+
+ def mod_DPT(pil_im, model_def):
+     depth_im = DPT.inference(img_array_rgb = np.array(pil_im), model_def = model_def)
+     return depth_im
+
+ def Main(model_dict):
+     st.set_page_config(layout = 'wide')
+     l_col, r_col = st.columns(2)
+     show_miro_logo(st_asset = l_col)
+     with l_col.expander('Monocular Depth: CNN vs Transformers'):
+         st.info('''
+             Comparison of two models: [BTS (CNN)](https://github.com/ErenBalatkan/Bts-PyTorch)
+             and [DPT (Transformer)](https://huggingface.co/Intel/dpt-large)
+             ''')
+
+     im = get_image(st_asset = r_col.expander('Input Image', expanded = True), extension_list = ['jpg', 'jpeg'])
+     model_name = r_col.selectbox('Pick Model', options = ['DPT', 'BTS'])  # only DPT is wired up in this commit
+
+     if im:
+         model_def = model_dict  # reuse the model loaded once in __main__
+         d_im = mod_DPT(pil_im = im, model_def = model_def)
+
+         l_col, r_col = st.columns(2)
+         l_col.image(im, caption = 'Input Image')
+         r_col.image(d_im, caption = 'Depth Map')
+     else:
+         st.warning('please provide an image :point_up:')
+
+ if __name__ == '__main__':
+     model_dict = DPT.load_model()
+     Main(model_dict = model_dict)
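Note that Streamlit re-executes the whole script on every widget interaction, so even with the model loaded in `__main__`, `DPT.load_model` runs again on each rerun. A sketch of the usual workaround, assuming the legacy `st.cache` API of Streamlit releases from this era (newer versions would use `st.cache_resource` instead):

```python
# Sketch: cache the MiDaS model across Streamlit reruns.
# Assumes the legacy st.cache API; torch modules are not hashable,
# hence allow_output_mutation=True.
import streamlit as st
import DPT

@st.cache(allow_output_mutation=True)
def get_model(model_type='DPT_Large'):
    return DPT.load_model(model_type=model_type)

# In Main(): model_def = get_model()  # loaded once, reused on every rerun
```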
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ opencv-python-headless>=4.5.5.64
+ torch==1.8.0
+ #matplotlib==3.1.3
+ numpy>=1.15.2
+ Pillow>=6.2.0
+ # DPT
+ timm==0.5.4
+ # BTS
+ albumentations>=1.1.0
+ torchvision==0.9.0
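The `torch==1.8.0` / `torchvision==0.9.0` pins are the matching release pair. Streamlit itself is not listed because a Hugging Face Space installs it from the Space's `sdk` setting rather than `requirements.txt`. A quick sanity check that the environment resolved as pinned (a sketch, not part of the commit):

```python
# Sanity check that the pinned versions resolved together.
import torch, torchvision, timm

print(torch.__version__, torchvision.__version__, timm.__version__)
assert torch.__version__.startswith('1.8'), 'expected torch==1.8.0'
assert torchvision.__version__.startswith('0.9'), 'expected torchvision==0.9.0'
```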