vijul.shah committed
Commit 99cd14f
1 Parent(s): 3733e70

Can download data as csv file and some UI changes

Files changed (3)
  1. app.py +25 -216
  2. app_utils.py +370 -5
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,251 +1,60 @@
-import os
 import sys
-import tempfile
 import os.path as osp
-from PIL import Image
-from io import BytesIO
-import numpy as np
-import pandas as pd
 import streamlit as st
-from PIL import ImageOps
-from matplotlib import pyplot as plt
-import altair as alt
 
 root_path = osp.abspath(osp.join(__file__, osp.pardir))
 sys.path.append(root_path)
 
 from registry_utils import import_registered_modules
 from app_utils import (
-    extract_frames,
     is_image,
     is_video,
-    convert_diameter,
-    overlay_text_on_frame,
-    process_frames,
-    process_video,
-    resize_frame,
+    process_image_and_vizualize_data,
+    process_video_and_visualize_data,
+    set_frames_processed_count_placeholder,
+    set_input_image_on_ui,
+    set_input_video_on_ui,
+    set_page_info,
+    set_sidebar_info,
 )
 
 import_registered_modules()
 
-CAM_METHODS = ["CAM"]
-TV_MODELS = ["ResNet18", "ResNet50"]
-SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
-UPSCALE = [2, 4]
-UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
-LABEL_MAP = ["left_pupil", "right_pupil"]
-
 
 def main():
-    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
-    st.title("EyeDentify Playground")
-    cols = st.columns((1, 1))
-    cols[0].header("Input")
-    cols[-1].header("Prediction")
-
-    st.sidebar.title("Upload Face or Eye")
-    uploaded_file = st.sidebar.file_uploader(
-        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
-    )
+    cols = set_page_info()
+    uploaded_file, pupil_selection, tv_model, blink_detection = set_sidebar_info()
 
     if uploaded_file is not None:
        file_extension = uploaded_file.name.split(".")[-1]
+        st.session_state["file_extension"] = file_extension
 
         if is_image(file_extension):
-            input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
-            # NOTE: images taken with phone camera has an EXIF data field which often rotates images taken with the phone in a tilted position. PIL has a utility function that removes this data and ‘uprights’ the image.
-            input_img = ImageOps.exif_transpose(input_img)
-            input_img = resize_frame(input_img, max_width=640, max_height=480)
-            input_img = resize_frame(input_img, max_width=640, max_height=480)
-            cols[0].image(input_img, use_column_width=True)
-            st.session_state.total_frames = 1
-
+            input_img = set_input_image_on_ui(uploaded_file, cols)
+            st.session_state["input_img"] = input_img
         elif is_video(file_extension):
-            tfile = tempfile.NamedTemporaryFile(delete=False)
-            tfile.write(uploaded_file.read())
-            video_path = tfile.name
-            video_frames = extract_frames(video_path)
-            cols[0].video(video_path)
-            st.session_state.total_frames = len(video_frames)
-
-        st.session_state.current_frame = 0
-        st.session_state.frame_placeholder = cols[0].empty()
-        txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
-        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
-
-    st.sidebar.title("Setup")
-    pupil_selection = st.sidebar.selectbox(
-        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
-    )
-    tv_model = st.sidebar.selectbox("Classification model", ["ResNet18", "ResNet50"], help="Supported Models")
-
-    blink_detection = st.sidebar.checkbox("Detect Blinks")
-
-    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)
+            video_frames, video_path = set_input_video_on_ui(uploaded_file, cols)
+            st.session_state["video_frames"] = video_frames
+            st.session_state["video_path"] = video_path
+
+        set_frames_processed_count_placeholder(cols)
 
     if st.sidebar.button("Predict Diameter & Compute CAM"):
         if uploaded_file is None:
             st.sidebar.error("Please upload an image or video")
         else:
             with st.spinner("Analyzing..."):
-
-                if is_image(file_extension):
-                    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
-                        cols,
-                        [input_img],
-                        tv_model,
-                        pupil_selection,
-                        cam_method=CAM_METHODS[-1],
-                        blink_detection=blink_detection,
+                if is_image(st.session_state.get("file_extension")):
+                    input_img = st.session_state.get("input_img")
+                    process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection)
+                elif is_video(st.session_state.get("file_extension")):
+                    video_frames = st.session_state.get("video_frames")
+                    video_path = st.session_state.get("video_path")
+                    process_video_and_visualize_data(
+                        cols, video_frames, tv_model, pupil_selection, blink_detection, video_path
                     )
-                    # for ff in face_frames:
-                    #     if ff["has_face"]:
-                    #         cols[1].image(face_frames[0]["img"], use_column_width=True)
-
-                    input_frames_keys = input_frames.keys()
-                    video_cols = cols[1].columns(len(input_frames_keys))
-                    for i, eye_type in enumerate(input_frames_keys):
-                        video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
-
-                    output_frames_keys = output_frames.keys()
-                    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
-                    for i, eye_type in enumerate(output_frames_keys):
-                        height, width, c = output_frames[eye_type][0].shape
-                        video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
-
-                        frame = np.zeros((height, width, c), dtype=np.uint8)
-                        text = f"{predicted_diameters[eye_type][0]:.2f}"
-                        frame = overlay_text_on_frame(frame, text)
-                        video_cols[i].image(frame, use_column_width=True)
-
-                elif is_video(file_extension):
-                    output_video_path = f"{root_path}/tmp.webm"
-                    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
-                        cols,
-                        video_frames,
-                        tv_model,
-                        pupil_selection,
-                        output_video_path,
-                        cam_method=CAM_METHODS[-1],
-                        blink_detection=blink_detection,
-                    )
-                    os.remove(video_path)
-
-                    num_columns = len(predicted_diameters)
-
-                    # Create a layout for the charts
-                    cols = st.columns(num_columns)
-
-                    # colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
-                    colors = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray
-
-                    # Iterate through categories and assign charts to columns
-                    for i, (category, values) in enumerate(predicted_diameters.items()):
-                        with cols[i]:  # Directly use the column index
-                            # st.subheader(category)  # Add a subheader for the category
-
-                            # Convert values to numeric, replacing non-numeric values with None
-                            values = [convert_diameter(value) for value in values]
-
-                            # Create a DataFrame from the values for Altair
-                            df = pd.DataFrame(values, columns=[category])
-                            df["Frame"] = range(1, len(values) + 1)  # Create a frame column starting from 1
-
-                            # Get the min and max values for y-axis limits, ignoring None
-                            min_value = min(filter(lambda x: x is not None, values), default=None)
-                            max_value = max(filter(lambda x: x is not None, values), default=None)
-
-                            # Create an Altair chart with y-axis limits
-                            line_chart = (
-                                alt.Chart(df)
-                                .mark_line(color=colors[i])
-                                .encode(
-                                    x=alt.X("Frame:Q", title="Frame Number"),
-                                    y=alt.Y(
-                                        f"{category}:Q",
-                                        title="Diameter",
-                                        scale=alt.Scale(domain=[min_value, max_value]),
-                                    ),
-                                    tooltip=[
-                                        "Frame",
-                                        alt.Tooltip(f"{category}:Q", title="Diameter"),
-                                    ],
-                                )
-                                # .properties(title=f"{category} - Predicted Diameters")
-                                # .configure_axis(grid=True)
-                            )
-                            points_chart = line_chart.mark_point(color=colors[i], filled=True)
-
-                            final_chart = (
-                                line_chart.properties(title=f"{category} - Predicted Diameters") + points_chart
-                            ).interactive()
-
-                            final_chart = final_chart.configure_axis(grid=True)
-
-                            # Display the Altair chart
-                            st.altair_chart(final_chart, use_container_width=True)
-
-                    if eyes_ratios is not None and len(eyes_ratios) > 0:
-                        df = pd.DataFrame(eyes_ratios, columns=["EAR"])
-                        df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1
-
-                        # Create an Altair chart for eyes_ratios
-                        line_chart = (
-                            alt.Chart(df)
-                            .mark_line(color=colors[-1])  # Set color of the line
-                            .encode(
-                                x=alt.X("Frame:Q", title="Frame Number"),
-                                y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
-                                tooltip=["Frame", "EAR"],
-                            )
-                            # .properties(title="Eyes Aspect Ratios (EARs)")
-                            # .configure_axis(grid=True)
-                        )
-                        points_chart = line_chart.mark_point(color=colors[-1], filled=True)
-
-                        # Create a horizontal rule at y=0.22
-                        line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
-
-                        line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")
-
-                        # Add text annotations for the lines
-                        text1 = (
-                            alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
-                            .mark_text(align="left", dx=100, dy=9, color="red", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        text2 = (
-                            alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
-                            .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        # Add gray area text for the region between red and green lines
-                        gray_area_text = (
-                            alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
-                            .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        # Combine all elements: line chart, points, rules, and text annotations
-                        final_chart = (
-                            line_chart.properties(title="Eyes Aspect Ratios (EARs)")
-                            + points_chart
-                            + line1
-                            + line2
-                            + text1
-                            + text2
-                            + gray_area_text
-                        ).interactive()
-
-                        # Configure axis properties at the chart level
-                        final_chart = final_chart.configure_axis(grid=True)
-
-                        # Display the Altair chart
-                        # st.subheader("Eyes Aspect Ratios (EARs)")
-                        st.altair_chart(final_chart, use_container_width=True)
 
 
 if __name__ == "__main__":
     main()
+    # run: streamlit run app.py --server.enableXsrfProtection false
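
Reviewer note: the refactor above leans on one Streamlit behavior worth keeping in mind — the script is re-executed from the top on every widget interaction, which is why the uploaded image/video is stashed in st.session_state when the file is first read and fetched back after the "Predict Diameter & Compute CAM" button triggers a rerun. A minimal sketch of that pattern (names here are illustrative, not from the commit):

import streamlit as st

# Streamlit reruns the whole script on each interaction, so anything that must
# survive the rerun goes in st.session_state, not in a local variable.
uploaded = st.file_uploader("Upload")
if uploaded is not None:
    st.session_state["file_bytes"] = uploaded.read()  # persists across reruns

if st.button("Process"):
    data = st.session_state.get("file_bytes")  # still available after the rerun
    if data is not None:
        st.write(f"{len(data)} bytes ready for processing")
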
app_utils.py CHANGED
@@ -1,16 +1,21 @@
 import base64
 from io import BytesIO
+import io
 import os
 import sys
 import cv2
 from matplotlib import pyplot as plt
 import numpy as np
+import pandas as pd
 import streamlit as st
 import torch
 import tempfile
 from PIL import Image
 from torchvision.transforms.functional import to_pil_image
 from torchvision import transforms
+from PIL import ImageOps
+import altair as alt
+
 
 from torchcam.methods import CAM
 from torchcam import methods as torchcam_methods
@@ -23,6 +28,10 @@ sys.path.append(root_path)
 from preprocessing.dataset_creation import EyeDentityDatasetCreation
 from utils import get_model
 
+CAM_METHODS = ["CAM"]
+# colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
+colors = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray
+
 
 @torch.no_grad()
 def load_model(model_configs, device="cpu"):
@@ -234,12 +243,12 @@ def process_frames(
     )
 
     preprocess_steps = [
-        transforms.ToTensor(),
         transforms.Resize(
            [32, 64],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
+        transforms.ToTensor(),
     ]
     preprocess_function = transforms.Compose(preprocess_steps)
 
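
Reviewer note: the preprocess_steps reorder in the hunk above is easy to misread as a no-op. Moving transforms.ToTensor() after transforms.Resize(...) means the bicubic resize now runs on the PIL image rather than on a float tensor; both orders are valid torchvision pipelines with the same output shape and [0, 1] range, though pixel values can differ slightly between the PIL and tensor interpolation backends. A self-contained sketch (the stand-in image is illustrative):

from PIL import Image
from torchvision import transforms

img = Image.new("RGB", (128, 64))  # stand-in for an eye crop

preprocess = transforms.Compose([
    transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    transforms.ToTensor(),  # PIL -> float tensor in [0, 1], shape (C, H, W)
])
assert tuple(preprocess(img).shape) == (3, 32, 64)
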
 
@@ -368,7 +377,11 @@ def process_frames(
 
         combined_frame = np.vstack((input_img_np, output_img_np, frame))
 
-        video_placeholders[eye_type].image(combined_frame, use_column_width=True)
+        img_base64 = pil_image_to_base64(Image.fromarray(combined_frame))
+        image_html = f'<div style="width: {str(50*len(selected_eyes))}%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+        video_placeholders[eye_type].markdown(image_html, unsafe_allow_html=True)
+
+        # video_placeholders[eye_type].image(combined_frame, use_column_width=True)
 
         st.session_state.current_frame = idx + 1
         txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
@@ -383,9 +396,9 @@ def process_frames(
 
 
 # Function to display video with autoplay and loop
-def display_video_with_autoplay(video_col, video_path):
+def display_video_with_autoplay(video_col, video_path, width):
     video_html = f"""
-    <video width="100%" height="auto" autoplay loop muted>
+    <video width="{str(width)}%" height="auto" autoplay loop muted>
         <source src="data:video/mp4;base64,{video_path}" type="video/mp4">
     </video>
     """
@@ -458,7 +471,359 @@ def combine_and_show_frames(input_frames, cam_frames, pred_diameters_frames, out
         video_base64 = base64.b64encode(video_bytes).decode("utf-8")
 
         # Display the combined video
-        display_video_with_autoplay(video_cols[eye_type], video_base64)
+        display_video_with_autoplay(video_cols[eye_type], video_base64, width=len(video_cols) * 50)
 
         # Clean up
         os.remove(output_path)
+
+
+def set_input_image_on_ui(uploaded_file, cols):
+    input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
+    # NOTE: images taken with phone camera has an EXIF data field which often rotates images taken with the phone in a tilted position. PIL has a utility function that removes this data and ‘uprights’ the image.
+    input_img = ImageOps.exif_transpose(input_img)
+    input_img = resize_frame(input_img, max_width=640, max_height=480)
+    input_img = resize_frame(input_img, max_width=640, max_height=480)
+    cols[0].image(input_img, use_column_width=True)
+    st.session_state.total_frames = 1
+    return input_img
+
+
+def set_input_video_on_ui(uploaded_file, cols):
+    tfile = tempfile.NamedTemporaryFile(delete=False)
+    tfile.write(uploaded_file.read())
+    video_path = tfile.name
+    video_frames = extract_frames(video_path)
+    cols[0].video(video_path)
+    st.session_state.total_frames = len(video_frames)
+    return video_frames, video_path
+
+
+def set_frames_processed_count_placeholder(cols):
+    st.session_state.current_frame = 0
+    st.session_state.frame_placeholder = cols[0].empty()
+    txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
+    st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
+
+
+def set_page_info():
+    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
+    st.title("EyeDentify Playground")
+    cols = st.columns((1, 1))
+    cols[0].header("Input")
+    cols[-1].header("Prediction")
+    return cols
+
+
+def set_sidebar_info():
+    LABEL_MAP = ["left_pupil", "right_pupil"]
+    TV_MODELS = ["ResNet18", "ResNet50"]
+
+    st.sidebar.title("Upload Face or Eye")
+    uploaded_file = st.sidebar.file_uploader(
+        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
+    )
+    st.sidebar.title("Setup")
+    pupil_selection = st.sidebar.selectbox(
+        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
+    )
+    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
+
+    blink_detection = st.sidebar.checkbox("Detect Blinks")
+
+    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)
+
+    return (uploaded_file, pupil_selection, tv_model, blink_detection)
+
+
+def pil_image_to_base64(img):
+    """Convert a PIL Image to a base64 encoded string."""
+    buffered = io.BytesIO()
+    img.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+    return img_str
+
+
+def process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection):
+    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
+        cols,
+        [input_img],
+        tv_model,
+        pupil_selection,
+        cam_method=CAM_METHODS[-1],
+        blink_detection=blink_detection,
+    )
+    # for ff in face_frames:
+    #     if ff["has_face"]:
+    #         cols[1].image(face_frames[0]["img"], use_column_width=True)
+
+    input_frames_keys = input_frames.keys()
+    video_cols = cols[1].columns(len(input_frames_keys))
+
+    for i, eye_type in enumerate(input_frames_keys):
+        # Check the pupil_selection and set the width accordingly
+        if pupil_selection == "both":
+            video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
+        else:
+            img_base64 = pil_image_to_base64(Image.fromarray(input_frames[eye_type][-1]))
+            image_html = f'<div style="width: 50%; margin-bottom: 1.2%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+
+    output_frames_keys = output_frames.keys()
+    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
+    for i, eye_type in enumerate(output_frames_keys):
+        height, width, c = output_frames[eye_type][0].shape
+        frame = np.zeros((height, width, c), dtype=np.uint8)
+        text = f"{predicted_diameters[eye_type][0]:.2f}"
+        frame = overlay_text_on_frame(frame, text)
+
+        if pupil_selection == "both":
+            video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
+            video_cols[i].image(frame, use_column_width=True)
+        else:
+            img_base64 = pil_image_to_base64(Image.fromarray(output_frames[eye_type][-1]))
+            image_html = f'<div style="width: 50%; margin-top: 1.2%; margin-bottom: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+            img_base64 = pil_image_to_base64(Image.fromarray(frame))
+            image_html = f'<div style="width: 50%; margin-top: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+
+    return None
+
+
+def plot_ears(eyes_ratios, eyes_df):
+    eyes_df["EAR"] = eyes_ratios
+    df = pd.DataFrame(eyes_ratios, columns=["EAR"])
+    df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1
+
+    # Create an Altair chart for eyes_ratios
+    line_chart = (
+        alt.Chart(df)
+        .mark_line(color=colors[-1])  # Set color of the line
+        .encode(
+            x=alt.X("Frame:Q", title="Frame Number"),
+            y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
+            tooltip=["Frame", "EAR"],
+        )
+        # .properties(title="Eyes Aspect Ratios (EARs)")
+        # .configure_axis(grid=True)
+    )
+    points_chart = line_chart.mark_point(color=colors[-1], filled=True)
+
+    # Create a horizontal rule at y=0.22
+    line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
+
+    line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")
+
+    # Add text annotations for the lines
+    text1 = (
+        alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
+        .mark_text(align="left", dx=100, dy=9, color="red", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    text2 = (
+        alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
+        .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    # Add gray area text for the region between red and green lines
+    gray_area_text = (
+        alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
+        .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    # Combine all elements: line chart, points, rules, and text annotations
+    final_chart = (
+        line_chart.properties(title="Eyes Aspect Ratios (EARs)")
+        + points_chart
+        + line1
+        + line2
+        + text1
+        + text2
+        + gray_area_text
+    ).interactive()
+
+    # Configure axis properties at the chart level
+    final_chart = final_chart.configure_axis(grid=True)
+
+    # Display the Altair chart
+    # st.subheader("Eyes Aspect Ratios (EARs)")
+    st.altair_chart(final_chart, use_container_width=True)
+    return eyes_df
+
+
+def plot_individual_charts(predicted_diameters, cols):
+    # Iterate through categories and assign charts to columns
+    for i, (category, values) in enumerate(predicted_diameters.items()):
+        with cols[i]:  # Directly use the column index
+            # st.subheader(category)  # Add a subheader for the category
+            if "left" in category:
+                selected_color = colors[0]
+            elif "right" in category:
+                selected_color = colors[1]
+            else:
+                selected_color = colors[i]
+
+            # Convert values to numeric, replacing non-numeric values with None
+            values = [convert_diameter(value) for value in values]
+
+            if "left" in category:
+                category_name = "Left Pupil Diameter"
+            else:
+                category_name = "Right Pupil Diameter"
+
+            # Create a DataFrame from the values for Altair
+            df = pd.DataFrame(
+                {
+                    "Frame": range(1, len(values) + 1),
+                    category_name: values,
+                }
+            )
+
+            # Get the min and max values for y-axis limits, ignoring None
+            min_value = min(filter(lambda x: x is not None, values), default=None)
+            max_value = max(filter(lambda x: x is not None, values), default=None)
+
+            # Create an Altair chart with y-axis limits
+            line_chart = (
+                alt.Chart(df)
+                .mark_line(color=selected_color)
+                .encode(
+                    x=alt.X("Frame:Q", title="Frame Number"),
+                    y=alt.Y(
+                        f"{category_name}:Q",
+                        title="Diameter",
+                        scale=alt.Scale(domain=[min_value, max_value]),
+                    ),
+                    tooltip=[
+                        "Frame",
+                        alt.Tooltip(f"{category_name}:Q", title="Diameter"),
+                    ],
+                )
+                # .properties(title=f"{category} - Predicted Diameters")
+                # .configure_axis(grid=True)
+            )
+            points_chart = line_chart.mark_point(color=selected_color, filled=True)
+
+            final_chart = (
+                line_chart.properties(
+                    title=f"{'Left Pupil' if 'left' in category else 'Right Pupil'} - Predicted Diameters"
+                )
+                + points_chart
+            ).interactive()
+
+            final_chart = final_chart.configure_axis(grid=True)
+
+            # Display the Altair chart
+            st.altair_chart(final_chart, use_container_width=True)
+    return df
+
+
+def plot_combined_charts(predicted_diameters):
+    all_min_values = []
+    all_max_values = []
+
+    # Create an empty DataFrame to store combined data for plotting
+    combined_df = pd.DataFrame()
+
+    # Iterate through categories and collect data
+    for category, values in predicted_diameters.items():
+        # Convert values to numeric, replacing non-numeric values with None
+        values = [convert_diameter(value) for value in values]
+
+        # Get the min and max values for y-axis limits, ignoring None
+        min_value = min(filter(lambda x: x is not None, values), default=None)
+        max_value = max(filter(lambda x: x is not None, values), default=None)
+
+        all_min_values.append(min_value)
+        all_max_values.append(max_value)
+
+        category = "left_pupil" if "left" in category else "right_pupil"
+
+        # Create a DataFrame from the values
+        df = pd.DataFrame(
+            {
+                "Diameter": values,
+                "Frame": range(1, len(values) + 1),  # Create a frame column starting from 1
+                "Category": category,  # Add a column to specify the category
+            }
+        )
+
+        # Append to combined DataFrame
+        combined_df = pd.concat([combined_df, df], ignore_index=True)
+
+    combined_chart = (
+        alt.Chart(combined_df)
+        .mark_line()
+        .encode(
+            x=alt.X("Frame:Q", title="Frame Number"),
+            y=alt.Y(
+                "Diameter:Q",
+                title="Diameter",
+                scale=alt.Scale(domain=[min(all_min_values), max(all_max_values)]),
+            ),
+            color=alt.Color("Category:N", scale=alt.Scale(range=colors), title="Pupil Type"),
+            tooltip=["Frame", "Diameter:Q", "Category:N"],
+        )
+    )
+    points_chart = combined_chart.mark_point(filled=True)
+
+    final_chart = (combined_chart.properties(title="Predicted Diameters") + points_chart).interactive()
+
+    final_chart = final_chart.configure_axis(grid=True)
+
+    # Display the combined chart
+    st.altair_chart(final_chart, use_container_width=True)
+
+    # --------------------------------------------
+    # Convert to a DataFrame
+    left_pupil_values = [convert_diameter(value) for value in predicted_diameters["left_eye"]]
+    right_pupil_values = [convert_diameter(value) for value in predicted_diameters["right_eye"]]
+
+    df = pd.DataFrame(
+        {
+            "Frame": range(1, len(left_pupil_values) + 1),
+            "Left Pupil Diameter": left_pupil_values,
+            "Right Pupil Diameter": right_pupil_values,
+        }
+    )
+
+    # Calculate the difference between left and right pupil diameters
+    df["Difference Value"] = df["Left Pupil Diameter"] - df["Right Pupil Diameter"]
+
+    # Determine the status of the difference
+    df["Difference Status"] = df.apply(
+        lambda row: "L>R" if row["Left Pupil Diameter"] > row["Right Pupil Diameter"] else "L<R",
+        axis=1,
+    )
+
+    return df
+
+
+def process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path):
+    output_video_path = f"{root_path}/tmp.webm"
+    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
+        cols,
+        video_frames,
+        tv_model,
+        pupil_selection,
+        output_video_path,
+        cam_method=CAM_METHODS[-1],
+        blink_detection=blink_detection,
+    )
+    os.remove(video_path)
+
+    num_columns = len(predicted_diameters)
+    cols = st.columns(num_columns)
+
+    if num_columns == 2:
+        df = plot_combined_charts(predicted_diameters)
+    else:
+        df = plot_individual_charts(predicted_diameters, cols)
+
+    if eyes_ratios is not None and len(eyes_ratios) > 0:
+        df = plot_ears(eyes_ratios, df)
+
+    st.dataframe(df, hide_index=True, use_container_width=True)
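
Reviewer note: the final st.dataframe call above is what backs the commit message's "download data as csv" claim — recent Streamlit releases (including the 1.38.0 pinned below) render dataframes with a toolbar that includes a CSV export. If an explicit control were ever preferred, a hedged alternative sketch using st.download_button (df here is stand-in data, not the commit's DataFrame):

import pandas as pd
import streamlit as st

df = pd.DataFrame({"Frame": [1, 2], "Left Pupil Diameter": [3.1, 3.2]})  # stand-in data
st.dataframe(df, hide_index=True, use_container_width=True)
st.download_button(
    label="Download data as CSV",
    data=df.to_csv(index=False).encode("utf-8"),
    file_name="pupil_diameters.csv",
    mime="text/csv",
)
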
requirements.txt CHANGED
@@ -20,7 +20,7 @@ dlib
 einops
 transformers
 gfpgan
-# streamlit
+streamlit==1.38.0
 mediapipe
 imutils
 scipy