Spaces · vijul.shah committed · commit 99cd14f · 1 parent: 3733e70

Can download data as csv file and some UI changes

Files changed:
- app.py +25 -216
- app_utils.py +370 -5
- requirements.txt +1 -1
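The commit moves most of the Streamlit UI logic out of app.py into named helpers in app_utils.py, and the video pipeline now ends by rendering the per-frame results as an interactive st.dataframe. The CSV download in the title presumably comes from that table: Streamlit's dataframe toolbar ships a built-in download-as-CSV button, so no explicit st.download_button is needed. A minimal standalone sketch of that behaviour (not repo code, data hypothetical):

import pandas as pd
import streamlit as st

# The rendered table's toolbar includes a "Download as CSV" button.
df = pd.DataFrame({"Frame": [1, 2, 3], "Left Pupil Diameter": [2.31, 2.28, 2.35]})
st.dataframe(df, hide_index=True, use_container_width=True)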
app.py CHANGED
@@ -1,251 +1,60 @@
-import os
 import sys
-import tempfile
 import os.path as osp
-from PIL import Image
-from io import BytesIO
-import numpy as np
-import pandas as pd
 import streamlit as st
-from PIL import ImageOps
-from matplotlib import pyplot as plt
-import altair as alt
 
 root_path = osp.abspath(osp.join(__file__, osp.pardir))
 sys.path.append(root_path)
 
 from registry_utils import import_registered_modules
 from app_utils import (
-    extract_frames,
     is_image,
     is_video,
-    convert_diameter,
-    overlay_text_on_frame,
-    process_frames,
-    process_video,
-    resize_frame,
+    process_image_and_vizualize_data,
+    process_video_and_visualize_data,
+    set_frames_processed_count_placeholder,
+    set_input_image_on_ui,
+    set_input_video_on_ui,
+    set_page_info,
+    set_sidebar_info,
 )
 
 import_registered_modules()
 
-CAM_METHODS = ["CAM"]
-TV_MODELS = ["ResNet18", "ResNet50"]
-SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
-UPSCALE = [2, 4]
-UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
-LABEL_MAP = ["left_pupil", "right_pupil"]
-
 
 def main():
-    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
-    st.title("EyeDentify Playground")
-    cols = st.columns((1, 1))
-    cols[0].header("Input")
-    cols[-1].header("Prediction")
-
-    st.sidebar.title("Upload Face or Eye")
-    uploaded_file = st.sidebar.file_uploader(
-        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
-    )
+    cols = set_page_info()
+    uploaded_file, pupil_selection, tv_model, blink_detection = set_sidebar_info()
 
     if uploaded_file is not None:
         file_extension = uploaded_file.name.split(".")[-1]
+        st.session_state["file_extension"] = file_extension
 
         if is_image(file_extension):
-            input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
-            # NOTE: images taken with phone camera has an EXIF data field which often rotates images taken with the phone in a tilted position. PIL has a utility function that removes this data and ‘uprights’ the image.
-            input_img = ImageOps.exif_transpose(input_img)
-            input_img = resize_frame(input_img, max_width=640, max_height=480)
-            input_img = resize_frame(input_img, max_width=640, max_height=480)
-            cols[0].image(input_img, use_column_width=True)
-            st.session_state.total_frames = 1
-
+            input_img = set_input_image_on_ui(uploaded_file, cols)
+            st.session_state["input_img"] = input_img
         elif is_video(file_extension):
-            tfile = tempfile.NamedTemporaryFile(delete=False)
-            tfile.write(uploaded_file.read())
-            video_path = tfile.name
-            video_frames = extract_frames(video_path)
-            cols[0].video(video_path)
-            st.session_state.total_frames = len(video_frames)
-
-        st.session_state.current_frame = 0
-        st.session_state.frame_placeholder = cols[0].empty()
-        txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
-        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
-
-    st.sidebar.title("Setup")
-    pupil_selection = st.sidebar.selectbox(
-        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
-    )
-    tv_model = st.sidebar.selectbox("Classification model", ["ResNet18", "ResNet50"], help="Supported Models")
+            video_frames, video_path = set_input_video_on_ui(uploaded_file, cols)
+            st.session_state["video_frames"] = video_frames
+            st.session_state["video_path"] = video_path
 
-    blink_detection = st.sidebar.checkbox("Detect Blinks")
-
-    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)
+        set_frames_processed_count_placeholder(cols)
 
     if st.sidebar.button("Predict Diameter & Compute CAM"):
        if uploaded_file is None:
            st.sidebar.error("Please upload an image or video")
        else:
            with st.spinner("Analyzing..."):
-
-                if is_image(file_extension):
-                    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
-                        cols,
-                        [input_img],
-                        tv_model,
-                        pupil_selection,
-                        cam_method=CAM_METHODS[-1],
-                        blink_detection=blink_detection,
+                if is_image(st.session_state.get("file_extension")):
+                    input_img = st.session_state.get("input_img")
+                    process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection)
+                elif is_video(st.session_state.get("file_extension")):
+                    video_frames = st.session_state.get("video_frames")
+                    video_path = st.session_state.get("video_path")
+                    process_video_and_visualize_data(
+                        cols, video_frames, tv_model, pupil_selection, blink_detection, video_path
                    )
-                    # for ff in face_frames:
-                    #     if ff["has_face"]:
-                    #         cols[1].image(face_frames[0]["img"], use_column_width=True)
-
-                    input_frames_keys = input_frames.keys()
-                    video_cols = cols[1].columns(len(input_frames_keys))
-                    for i, eye_type in enumerate(input_frames_keys):
-                        video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
-
-                    output_frames_keys = output_frames.keys()
-                    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
-                    for i, eye_type in enumerate(output_frames_keys):
-                        height, width, c = output_frames[eye_type][0].shape
-                        video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
-
-                        frame = np.zeros((height, width, c), dtype=np.uint8)
-                        text = f"{predicted_diameters[eye_type][0]:.2f}"
-                        frame = overlay_text_on_frame(frame, text)
-                        video_cols[i].image(frame, use_column_width=True)
-
-                elif is_video(file_extension):
-                    output_video_path = f"{root_path}/tmp.webm"
-                    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
-                        cols,
-                        video_frames,
-                        tv_model,
-                        pupil_selection,
-                        output_video_path,
-                        cam_method=CAM_METHODS[-1],
-                        blink_detection=blink_detection,
-                    )
-                    os.remove(video_path)
-
-                    num_columns = len(predicted_diameters)
-
-                    # Create a layout for the charts
-                    cols = st.columns(num_columns)
-
-                    # colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
-                    colors = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray
-
-                    # Iterate through categories and assign charts to columns
-                    for i, (category, values) in enumerate(predicted_diameters.items()):
-                        with cols[i]:  # Directly use the column index
-                            # st.subheader(category)  # Add a subheader for the category
-
-                            # Convert values to numeric, replacing non-numeric values with None
-                            values = [convert_diameter(value) for value in values]
-
-                            # Create a DataFrame from the values for Altair
-                            df = pd.DataFrame(values, columns=[category])
-                            df["Frame"] = range(1, len(values) + 1)  # Create a frame column starting from 1
-
-                            # Get the min and max values for y-axis limits, ignoring None
-                            min_value = min(filter(lambda x: x is not None, values), default=None)
-                            max_value = max(filter(lambda x: x is not None, values), default=None)
-
-                            # Create an Altair chart with y-axis limits
-                            line_chart = (
-                                alt.Chart(df)
-                                .mark_line(color=colors[i])
-                                .encode(
-                                    x=alt.X("Frame:Q", title="Frame Number"),
-                                    y=alt.Y(
-                                        f"{category}:Q",
-                                        title="Diameter",
-                                        scale=alt.Scale(domain=[min_value, max_value]),
-                                    ),
-                                    tooltip=[
-                                        "Frame",
-                                        alt.Tooltip(f"{category}:Q", title="Diameter"),
-                                    ],
-                                )
-                                # .properties(title=f"{category} - Predicted Diameters")
-                                # .configure_axis(grid=True)
-                            )
-                            points_chart = line_chart.mark_point(color=colors[i], filled=True)
-
-                            final_chart = (
-                                line_chart.properties(title=f"{category} - Predicted Diameters") + points_chart
-                            ).interactive()
-
-                            final_chart = final_chart.configure_axis(grid=True)
-
-                            # Display the Altair chart
-                            st.altair_chart(final_chart, use_container_width=True)
-
-                    if eyes_ratios is not None and len(eyes_ratios) > 0:
-                        df = pd.DataFrame(eyes_ratios, columns=["EAR"])
-                        df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1
-
-                        # Create an Altair chart for eyes_ratios
-                        line_chart = (
-                            alt.Chart(df)
-                            .mark_line(color=colors[-1])  # Set color of the line
-                            .encode(
-                                x=alt.X("Frame:Q", title="Frame Number"),
-                                y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
-                                tooltip=["Frame", "EAR"],
-                            )
-                            # .properties(title="Eyes Aspect Ratios (EARs)")
-                            # .configure_axis(grid=True)
-                        )
-                        points_chart = line_chart.mark_point(color=colors[-1], filled=True)
-
-                        # Create a horizontal rule at y=0.22
-                        line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
-
-                        line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")
-
-                        # Add text annotations for the lines
-                        text1 = (
-                            alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
-                            .mark_text(align="left", dx=100, dy=9, color="red", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        text2 = (
-                            alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
-                            .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        # Add gray area text for the region between red and green lines
-                        gray_area_text = (
-                            alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
-                            .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
-                            .encode(y="y:Q", text="label:N")
-                        )
-
-                        # Combine all elements: line chart, points, rules, and text annotations
-                        final_chart = (
-                            line_chart.properties(title="Eyes Aspect Ratios (EARs)")
-                            + points_chart
-                            + line1
-                            + line2
-                            + text1
-                            + text2
-                            + gray_area_text
-                        ).interactive()
-
-                        # Configure axis properties at the chart level
-                        final_chart = final_chart.configure_axis(grid=True)
-
-                        # Display the Altair chart
-                        # st.subheader("Eyes Aspect Ratios (EARs)")
-                        st.altair_chart(final_chart, use_container_width=True)
 
 
 if __name__ == "__main__":
     main()
+# run: streamlit run app.py --server.enableXsrfProtection false
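The refactored main() keeps everything derived from the upload in st.session_state rather than in plain locals. Streamlit reruns the whole script on every widget interaction, and st.session_state is the supported way to carry values across the rerun triggered by the "Predict Diameter & Compute CAM" click. A minimal sketch of the pattern (standalone, names hypothetical):

import streamlit as st

uploaded = st.sidebar.file_uploader("Upload Image", type=["png", "jpg"])
if uploaded is not None:
    # Persist derived data; st.session_state survives the rerun a button click causes.
    st.session_state["img_bytes"] = uploaded.getvalue()

if st.sidebar.button("Run"):
    img_bytes = st.session_state.get("img_bytes")
    if img_bytes is not None:
        st.image(img_bytes)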
app_utils.py CHANGED
@@ -1,16 +1,21 @@
 import base64
 from io import BytesIO
+import io
 import os
 import sys
 import cv2
 from matplotlib import pyplot as plt
 import numpy as np
+import pandas as pd
 import streamlit as st
 import torch
 import tempfile
 from PIL import Image
 from torchvision.transforms.functional import to_pil_image
 from torchvision import transforms
+from PIL import ImageOps
+import altair as alt
+
 
 from torchcam.methods import CAM
 from torchcam import methods as torchcam_methods
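Note the overlap in the import hunk: `import io` joins the existing `from io import BytesIO` because the new helpers below use both spellings (io.BytesIO() in pil_image_to_base64, BytesIO(...) in set_input_image_on_ui). They name the same class:

import io
from io import BytesIO

assert io.BytesIO is BytesIO  # two import spellings, one class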
@@ -23,6 +28,10 @@ sys.path.append(root_path)
 from preprocessing.dataset_creation import EyeDentityDatasetCreation
 from utils import get_model
 
+CAM_METHODS = ["CAM"]
+# colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
+colors = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray
+
 
 @torch.no_grad()
 def load_model(model_configs, device="cpu"):
@@ -234,12 +243,12 @@ def process_frames(
     )
 
     preprocess_steps = [
-        transforms.ToTensor(),
         transforms.Resize(
             [32, 64],
             interpolation=transforms.InterpolationMode.BICUBIC,
             antialias=True,
         ),
+        transforms.ToTensor(),
     ]
     preprocess_function = transforms.Compose(preprocess_steps)
 
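The reorder moves ToTensor() after Resize(), so the bicubic resize now runs on the uint8 PIL image rather than on a float tensor. torchvision's Resize accepts either input, so both pipelines are valid and yield the same output shape; only small numeric differences arise from where the interpolation happens. A standalone comparison sketch (not repo code):

from PIL import Image
from torchvision import transforms

img = Image.new("RGB", (128, 96))
resize = transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)

old_pipe = transforms.Compose([transforms.ToTensor(), resize])  # resize the float tensor
new_pipe = transforms.Compose([resize, transforms.ToTensor()])  # resize the PIL image first

print(old_pipe(img).shape)  # torch.Size([3, 32, 64])
print(new_pipe(img).shape)  # torch.Size([3, 32, 64])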
@@ -368,7 +377,11 @@ def process_frames(
 
         combined_frame = np.vstack((input_img_np, output_img_np, frame))
 
-        video_placeholders[eye_type].image(combined_frame, use_column_width=True)
+        img_base64 = pil_image_to_base64(Image.fromarray(combined_frame))
+        image_html = f'<div style="width: {str(50*len(selected_eyes))}%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+        video_placeholders[eye_type].markdown(image_html, unsafe_allow_html=True)
+
+        # video_placeholders[eye_type].image(combined_frame, use_column_width=True)
 
         st.session_state.current_frame = idx + 1
         txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
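Here st.image is swapped for a base64 data URI inside raw HTML: a wrapping <div> lets the app size the frame with CSS percentages, which st.image (use_column_width or fixed pixel widths) does not offer. The trick in isolation, with pil_image_to_base64 inlined (it is defined later in this diff):

import base64
import io

import streamlit as st
from PIL import Image

img = Image.new("RGB", (64, 32), "gray")  # stand-in for a combined frame
buf = io.BytesIO()
img.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()

# The outer div's CSS width plays the role of 50% per selected eye.
st.markdown(
    f'<div style="width: 50%;"><img src="data:image/png;base64,{b64}" style="width: 100%;"></div>',
    unsafe_allow_html=True,
)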
@@ -383,9 +396,9 @@ def process_frames(
 
 
 # Function to display video with autoplay and loop
-def display_video_with_autoplay(video_col, video_path):
+def display_video_with_autoplay(video_col, video_path, width):
     video_html = f"""
-        <video width="
+        <video width="{str(width)}%" height="auto" autoplay loop muted>
         <source src="data:video/mp4;base64,{video_path}" type="video/mp4">
         </video>
     """
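The new width argument is a percentage; the call site below passes len(video_cols) * 50, i.e. 50% per eye column and 100% when both eyes are shown. Embedding the clip as a base64 data URI inside a raw <video> tag keeps the autoplay, loop, and muted attributes under the app's control. A self-contained usage sketch (file name hypothetical):

import base64
import streamlit as st

with open("clip.mp4", "rb") as f:  # any small mp4
    clip_b64 = base64.b64encode(f.read()).decode("utf-8")

st.markdown(
    f'<video width="100%" height="auto" autoplay loop muted>'
    f'<source src="data:video/mp4;base64,{clip_b64}" type="video/mp4"></video>',
    unsafe_allow_html=True,
)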
@@ -458,7 +471,359 @@ def combine_and_show_frames(input_frames, cam_frames, pred_diameters_frames, out
         video_base64 = base64.b64encode(video_bytes).decode("utf-8")
 
         # Display the combined video
-        display_video_with_autoplay(video_cols[eye_type], video_base64)
+        display_video_with_autoplay(video_cols[eye_type], video_base64, width=len(video_cols) * 50)
 
         # Clean up
         os.remove(output_path)
+
+
+def set_input_image_on_ui(uploaded_file, cols):
+    input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
+    # NOTE: images taken with phone camera has an EXIF data field which often rotates images taken with the phone in a tilted position. PIL has a utility function that removes this data and ‘uprights’ the image.
+    input_img = ImageOps.exif_transpose(input_img)
+    input_img = resize_frame(input_img, max_width=640, max_height=480)
+    input_img = resize_frame(input_img, max_width=640, max_height=480)
+    cols[0].image(input_img, use_column_width=True)
+    st.session_state.total_frames = 1
+    return input_img
+
+
+def set_input_video_on_ui(uploaded_file, cols):
+    tfile = tempfile.NamedTemporaryFile(delete=False)
+    tfile.write(uploaded_file.read())
+    video_path = tfile.name
+    video_frames = extract_frames(video_path)
+    cols[0].video(video_path)
+    st.session_state.total_frames = len(video_frames)
+    return video_frames, video_path
+
+
+def set_frames_processed_count_placeholder(cols):
+    st.session_state.current_frame = 0
+    st.session_state.frame_placeholder = cols[0].empty()
+    txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
+    st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
+
+
+def set_page_info():
+    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
+    st.title("EyeDentify Playground")
+    cols = st.columns((1, 1))
+    cols[0].header("Input")
+    cols[-1].header("Prediction")
+    return cols
+
+
+def set_sidebar_info():
+    LABEL_MAP = ["left_pupil", "right_pupil"]
+    TV_MODELS = ["ResNet18", "ResNet50"]
+
+    st.sidebar.title("Upload Face or Eye")
+    uploaded_file = st.sidebar.file_uploader(
+        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
+    )
+    st.sidebar.title("Setup")
+    pupil_selection = st.sidebar.selectbox(
+        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
+    )
+    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
+
+    blink_detection = st.sidebar.checkbox("Detect Blinks")
+
+    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)
+
+    return (uploaded_file, pupil_selection, tv_model, blink_detection)
+
+
+def pil_image_to_base64(img):
+    """Convert a PIL Image to a base64 encoded string."""
+    buffered = io.BytesIO()
+    img.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode()
+    return img_str
+
+
+def process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection):
+    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
+        cols,
+        [input_img],
+        tv_model,
+        pupil_selection,
+        cam_method=CAM_METHODS[-1],
+        blink_detection=blink_detection,
+    )
+    # for ff in face_frames:
+    #     if ff["has_face"]:
+    #         cols[1].image(face_frames[0]["img"], use_column_width=True)
+
+    input_frames_keys = input_frames.keys()
+    video_cols = cols[1].columns(len(input_frames_keys))
+
+    for i, eye_type in enumerate(input_frames_keys):
+        # Check the pupil_selection and set the width accordingly
+        if pupil_selection == "both":
+            video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
+        else:
+            img_base64 = pil_image_to_base64(Image.fromarray(input_frames[eye_type][-1]))
+            image_html = f'<div style="width: 50%; margin-bottom: 1.2%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+
+    output_frames_keys = output_frames.keys()
+    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
+    for i, eye_type in enumerate(output_frames_keys):
+        height, width, c = output_frames[eye_type][0].shape
+        frame = np.zeros((height, width, c), dtype=np.uint8)
+        text = f"{predicted_diameters[eye_type][0]:.2f}"
+        frame = overlay_text_on_frame(frame, text)
+
+        if pupil_selection == "both":
+            video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
+            video_cols[i].image(frame, use_column_width=True)
+        else:
+            img_base64 = pil_image_to_base64(Image.fromarray(output_frames[eye_type][-1]))
+            image_html = f'<div style="width: 50%; margin-top: 1.2%; margin-bottom: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+            img_base64 = pil_image_to_base64(Image.fromarray(frame))
+            image_html = f'<div style="width: 50%; margin-top: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
+            video_cols[i].markdown(image_html, unsafe_allow_html=True)
+
+    return None
+
+
+def plot_ears(eyes_ratios, eyes_df):
+    eyes_df["EAR"] = eyes_ratios
+    df = pd.DataFrame(eyes_ratios, columns=["EAR"])
+    df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1
+
+    # Create an Altair chart for eyes_ratios
+    line_chart = (
+        alt.Chart(df)
+        .mark_line(color=colors[-1])  # Set color of the line
+        .encode(
+            x=alt.X("Frame:Q", title="Frame Number"),
+            y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
+            tooltip=["Frame", "EAR"],
+        )
+        # .properties(title="Eyes Aspect Ratios (EARs)")
+        # .configure_axis(grid=True)
+    )
+    points_chart = line_chart.mark_point(color=colors[-1], filled=True)
+
+    # Create a horizontal rule at y=0.22
+    line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
+
+    line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")
+
+    # Add text annotations for the lines
+    text1 = (
+        alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
+        .mark_text(align="left", dx=100, dy=9, color="red", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    text2 = (
+        alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
+        .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    # Add gray area text for the region between red and green lines
+    gray_area_text = (
+        alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
+        .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
+        .encode(y="y:Q", text="label:N")
+    )
+
+    # Combine all elements: line chart, points, rules, and text annotations
+    final_chart = (
+        line_chart.properties(title="Eyes Aspect Ratios (EARs)")
+        + points_chart
+        + line1
+        + line2
+        + text1
+        + text2
+        + gray_area_text
+    ).interactive()
+
+    # Configure axis properties at the chart level
+    final_chart = final_chart.configure_axis(grid=True)
+
+    # Display the Altair chart
+    # st.subheader("Eyes Aspect Ratios (EARs)")
+    st.altair_chart(final_chart, use_container_width=True)
+    return eyes_df
+
+
+def plot_individual_charts(predicted_diameters, cols):
+    # Iterate through categories and assign charts to columns
+    for i, (category, values) in enumerate(predicted_diameters.items()):
+        with cols[i]:  # Directly use the column index
+            # st.subheader(category)  # Add a subheader for the category
+            if "left" in category:
+                selected_color = colors[0]
+            elif "right" in category:
+                selected_color = colors[1]
+            else:
+                selected_color = colors[i]
+
+            # Convert values to numeric, replacing non-numeric values with None
+            values = [convert_diameter(value) for value in values]
+
+            if "left" in category:
+                category_name = "Left Pupil Diameter"
+            else:
+                category_name = "Right Pupil Diameter"
+
+            # Create a DataFrame from the values for Altair
+            df = pd.DataFrame(
+                {
+                    "Frame": range(1, len(values) + 1),
+                    category_name: values,
+                }
+            )
+
+            # Get the min and max values for y-axis limits, ignoring None
+            min_value = min(filter(lambda x: x is not None, values), default=None)
+            max_value = max(filter(lambda x: x is not None, values), default=None)
+
+            # Create an Altair chart with y-axis limits
+            line_chart = (
+                alt.Chart(df)
+                .mark_line(color=selected_color)
+                .encode(
+                    x=alt.X("Frame:Q", title="Frame Number"),
+                    y=alt.Y(
+                        f"{category_name}:Q",
+                        title="Diameter",
+                        scale=alt.Scale(domain=[min_value, max_value]),
+                    ),
+                    tooltip=[
+                        "Frame",
+                        alt.Tooltip(f"{category_name}:Q", title="Diameter"),
+                    ],
+                )
+                # .properties(title=f"{category} - Predicted Diameters")
+                # .configure_axis(grid=True)
+            )
+            points_chart = line_chart.mark_point(color=selected_color, filled=True)
+
+            final_chart = (
+                line_chart.properties(
+                    title=f"{'Left Pupil' if 'left' in category else 'Right Pupil'} - Predicted Diameters"
+                )
+                + points_chart
+            ).interactive()
+
+            final_chart = final_chart.configure_axis(grid=True)
+
+            # Display the Altair chart
+            st.altair_chart(final_chart, use_container_width=True)
+    return df
+
+
+def plot_combined_charts(predicted_diameters):
+    all_min_values = []
+    all_max_values = []
+
+    # Create an empty DataFrame to store combined data for plotting
+    combined_df = pd.DataFrame()
+
+    # Iterate through categories and collect data
+    for category, values in predicted_diameters.items():
+        # Convert values to numeric, replacing non-numeric values with None
+        values = [convert_diameter(value) for value in values]
+
+        # Get the min and max values for y-axis limits, ignoring None
+        min_value = min(filter(lambda x: x is not None, values), default=None)
+        max_value = max(filter(lambda x: x is not None, values), default=None)
+
+        all_min_values.append(min_value)
+        all_max_values.append(max_value)
+
+        category = "left_pupil" if "left" in category else "right_pupil"
+
+        # Create a DataFrame from the values
+        df = pd.DataFrame(
+            {
+                "Diameter": values,
+                "Frame": range(1, len(values) + 1),  # Create a frame column starting from 1
+                "Category": category,  # Add a column to specify the category
+            }
+        )
+
+        # Append to combined DataFrame
+        combined_df = pd.concat([combined_df, df], ignore_index=True)
+
+    combined_chart = (
+        alt.Chart(combined_df)
+        .mark_line()
+        .encode(
+            x=alt.X("Frame:Q", title="Frame Number"),
+            y=alt.Y(
+                "Diameter:Q",
+                title="Diameter",
+                scale=alt.Scale(domain=[min(all_min_values), max(all_max_values)]),
+            ),
+            color=alt.Color("Category:N", scale=alt.Scale(range=colors), title="Pupil Type"),
+            tooltip=["Frame", "Diameter:Q", "Category:N"],
+        )
+    )
+    points_chart = combined_chart.mark_point(filled=True)
+
+    final_chart = (combined_chart.properties(title="Predicted Diameters") + points_chart).interactive()
+
+    final_chart = final_chart.configure_axis(grid=True)
+
+    # Display the combined chart
+    st.altair_chart(final_chart, use_container_width=True)
+
+    # --------------------------------------------
+    # Convert to a DataFrame
+    left_pupil_values = [convert_diameter(value) for value in predicted_diameters["left_eye"]]
+    right_pupil_values = [convert_diameter(value) for value in predicted_diameters["right_eye"]]
+
+    df = pd.DataFrame(
+        {
+            "Frame": range(1, len(left_pupil_values) + 1),
+            "Left Pupil Diameter": left_pupil_values,
+            "Right Pupil Diameter": right_pupil_values,
+        }
+    )
+
+    # Calculate the difference between left and right pupil diameters
+    df["Difference Value"] = df["Left Pupil Diameter"] - df["Right Pupil Diameter"]
+
+    # Determine the status of the difference
+    df["Difference Status"] = df.apply(
+        lambda row: "L>R" if row["Left Pupil Diameter"] > row["Right Pupil Diameter"] else "L<R",
+        axis=1,
+    )
+
+    return df
+
+
+def process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path):
+    output_video_path = f"{root_path}/tmp.webm"
+    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
+        cols,
+        video_frames,
+        tv_model,
+        pupil_selection,
+        output_video_path,
+        cam_method=CAM_METHODS[-1],
+        blink_detection=blink_detection,
+    )
+    os.remove(video_path)
+
+    num_columns = len(predicted_diameters)
+    cols = st.columns(num_columns)
+
+    if num_columns == 2:
+        df = plot_combined_charts(predicted_diameters)
+    else:
+        df = plot_individual_charts(predicted_diameters, cols)
+
+    if eyes_ratios is not None and len(eyes_ratios) > 0:
+        df = plot_ears(eyes_ratios, df)
+
+    st.dataframe(df, hide_index=True, use_container_width=True)
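plot_ears draws rule marks at EAR 0.22 ("Definite Blinks (<=0.22)") and 0.25 ("No Blinks (>=0.25)"), labelling the band in between a gray area. The eye aspect ratio itself is computed elsewhere in the pipeline; for background, it is the standard Soukupová–Čech blink measure, sketched here (not this repo's code):

import numpy as np

def eye_aspect_ratio(eye: np.ndarray) -> float:
    """EAR for six eye landmarks p1..p6, given as a (6, 2) array."""
    v1 = np.linalg.norm(eye[1] - eye[5])  # vertical distance ||p2 - p6||
    v2 = np.linalg.norm(eye[2] - eye[4])  # vertical distance ||p3 - p5||
    h = np.linalg.norm(eye[0] - eye[3])   # horizontal distance ||p1 - p4||
    return (v1 + v2) / (2.0 * h)

open_eye = np.array([[0, 0], [1, 1], [2, 1], [3, 0], [2, -1], [1, -1]], dtype=float)
print(round(eye_aspect_ratio(open_eye), 3))  # 0.667, well above the 0.25 "no blink" rule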
requirements.txt CHANGED
@@ -20,7 +20,7 @@ dlib
 einops
 transformers
 gfpgan
-
+streamlit==1.38.0
 mediapipe
 imutils
 scipy
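The only requirements change pins Streamlit where a blank line used to be. Pinning to 1.38.0 is sensible here: the code above still calls use_column_width=True, which newer Streamlit releases deprecate in favour of use_container_width, so a floating version could start emitting warnings or change the layout. A quick post-install sanity check (hypothetical snippet):

import streamlit

assert streamlit.__version__ == "1.38.0", streamlit.__version__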