import base64
import io
import os
import os.path as osp
import sys
import tempfile
from io import BytesIO

import altair as alt
import cv2
import numpy as np
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageOps
from torchcam import methods as torchcam_methods
from torchcam.methods import CAM
from torchcam.utils import overlay_mask
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

# Make the project root importable so the local packages resolve.
root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model

CAM_METHODS = ["CAM"]
# colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"]  # Green, Red, Blue, Orange
colors = ["#1f77b4", "#ff7f0e", "#636363"]  # Blue, Orange, Gray

def load_model(model_configs, device="cpu"):
    """Loads the pre-trained model."""
    model_path = os.path.join(root_path, model_configs["model_path"])
    model_dict = torch.load(model_path, map_location=device)
    model = get_model(model_configs=model_configs)
    model.load_state_dict(model_dict)
    model = model.to(device).eval()
    return model

def extract_frames(video_path):
    """Extracts frames from a video file."""
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    success, image = vidcap.read()
    while success:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        frames.append(image_rgb)
        success, image = vidcap.read()
    vidcap.release()
    return frames

def resize_frame(image, max_width=640, max_height=480):
    """Downscales a frame (preserving aspect ratio) so it fits within the given bounds."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    original_size = image.size
    # Resize the frame similarly to the image resizing logic
    if original_size[0] == original_size[1] and original_size[0] >= 256:
        max_size = (256, 256)
    else:
        max_size = list(original_size)
        if original_size[0] >= max_width:
            max_size[0] = max_width
        elif original_size[0] < 64:
            max_size[0] = 64
        if original_size[1] >= max_height:
            max_size[1] = max_height
        elif original_size[1] < 32:
            max_size[1] = 32
    image.thumbnail(max_size)
    # image = image.resize(max_size)
    return image

def is_image(file_extension):
    """Checks if the file is an image."""
    return file_extension.lower() in ["png", "jpeg", "jpg"]


def is_video(file_extension):
    """Checks if the file is a video."""
    return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]

def get_codec_and_extension(file_format):
    """Returns codec and file extension based on the format."""
    if file_format == "mp4":
        return "H264", ".mp4"
    elif file_format == "avi":
        return "MJPG", ".avi"
    elif file_format == "webm":
        return "VP80", ".webm"
    else:
        return "MJPG", ".avi"

def display_results(input_image, cam_frame, pupil_diameter, cols):
    """Displays the input image and the overlaid CAM result."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(input_image)
    axs[0].axis("off")
    axs[0].set_title("Input Image")
    axs[1].imshow(cam_frame)
    axs[1].axis("off")
    axs[1].set_title("Overlaid CAM")
    cols[-1].pyplot(fig)
    cols[-1].text(f"Pupil Diameter: {pupil_diameter:.2f} mm")

def preprocess_image(input_img, max_size=(256, 256)):
    """Resizes and preprocesses an image."""
    input_img.thumbnail(max_size)
    preprocess_steps = [
        transforms.ToTensor(),
        transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    ]
    return transforms.Compose(preprocess_steps)(input_img).unsqueeze(0)

def overlay_text_on_frame(frame, text, position=(16, 20)):
    """Writes text on the image frame using OpenCV."""
    return cv2.putText(frame, text, position, cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, cv2.LINE_AA)

def get_configs(blink_detection=False):
    """Builds the dataset-creation config (super-resolution is disabled by default)."""
    upscale = "-"
    upscale_method_or_model = "-"
    if upscale == "-":
        sr_configs = None
    else:
        sr_configs = {
            "method": upscale_method_or_model,
            "params": {"upscale": upscale},
        }
    config_file = {
        "sr_configs": sr_configs,
        "feature_extraction_configs": {
            "blink_detection": blink_detection,
            "upscale": upscale,
            "extraction_library": "mediapipe",
        },
    }
    return config_file

def setup(cols, pupil_selection, tv_model, output_path):
    """Loads the per-eye models and initializes result containers and video placeholders."""
    left_pupil_model = None
    left_pupil_cam_extractor = None
    right_pupil_model = None
    right_pupil_cam_extractor = None
    output_frames = {}
    input_frames = {}
    predicted_diameters = {}
    pred_diameters_frames = {}
    if pupil_selection == "both":
        selected_eyes = ["left_eye", "right_eye"]
    elif pupil_selection == "left_pupil":
        selected_eyes = ["left_eye"]
    elif pupil_selection == "right_pupil":
        selected_eyes = ["right_eye"]
    for i, eye_type in enumerate(selected_eyes):
        model_configs = {
            "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
            "registered_model_name": tv_model,
            "num_classes": 1,
        }
        if eye_type == "left_eye":
            left_pupil_model = load_model(model_configs)
            left_pupil_cam_extractor = None
        else:
            right_pupil_model = load_model(model_configs)
            right_pupil_cam_extractor = None
        output_frames[eye_type] = []
        input_frames[eye_type] = []
        predicted_diameters[eye_type] = []
        pred_diameters_frames[eye_type] = []
    video_placeholders = {}
    if output_path:
        video_cols = cols[1].columns(len(input_frames.keys()))
        for i, eye_type in enumerate(list(input_frames.keys())):
            video_placeholders[eye_type] = video_cols[i].empty()
    return (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        pred_diameters_frames,
        video_placeholders,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    )

def process_frames(
    cols, input_imgs, tv_model, pupil_selection, cam_method, output_path=None, codec=None, blink_detection=False
):
    """Runs pupil-diameter prediction and CAM visualization for each input frame."""
    config_file = get_configs(blink_detection)
    face_frames = []
    (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        pred_diameters_frames,
        video_placeholders,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    ) = setup(cols, pupil_selection, tv_model, output_path)
    ds_creation = EyeDentityDatasetCreation(
        feature_extraction_configs=config_file["feature_extraction_configs"],
        sr_configs=config_file["sr_configs"],
    )
    preprocess_steps = [
        transforms.Resize(
            [32, 64],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        transforms.ToTensor(),
    ]
    preprocess_function = transforms.Compose(preprocess_steps)
    eyes_ratios = []
    for idx, input_img in enumerate(input_imgs):
        img = np.array(input_img)
        ds_results = ds_creation(img)
        left_eye = None
        right_eye = None
        blinked = False
        eyes_ratio = None
        if ds_results is not None and "face" in ds_results:
            face_img = to_pil_image(ds_results["face"])
            has_face = True
        else:
            face_img = to_pil_image(np.zeros((256, 256, 3), dtype=np.uint8))
            has_face = False
        face_frames.append({"has_face": has_face, "img": face_img})
        if ds_results is not None and "eyes" in ds_results.keys():
            blinked = ds_results["eyes"]["blinked"]
            eyes_ratio = ds_results["eyes"]["eyes_ratio"]
            if eyes_ratio is not None:
                eyes_ratios.append(eyes_ratio)
            if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
                left_eye = ds_results["eyes"]["left_eye"]
                left_eye = to_pil_image(left_eye).convert("RGB")
                left_eye = preprocess_function(left_eye)
                left_eye = left_eye.unsqueeze(0)
            if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
                right_eye = ds_results["eyes"]["right_eye"]
                right_eye = to_pil_image(right_eye).convert("RGB")
                right_eye = preprocess_function(right_eye)
                right_eye = right_eye.unsqueeze(0)
        else:
            # No eyes detected: treat the whole input as the eye crop for the selected side(s).
            input_img = preprocess_function(input_img)
            input_img = input_img.unsqueeze(0)
            if pupil_selection == "left_pupil":
                left_eye = input_img
            elif pupil_selection == "right_pupil":
                right_eye = input_img
            else:
                left_eye = input_img
                right_eye = input_img
        for i, eye_type in enumerate(selected_eyes):
            if blinked:
                if left_eye is not None and eye_type == "left_eye":
                    _, height, width = left_eye.squeeze(0).shape
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    _, height, width = right_eye.squeeze(0).shape
                    input_image_pil = to_pil_image(right_eye.squeeze(0))
                input_img_np = np.array(input_image_pil)
                zeros_img = to_pil_image(np.zeros((height, width, 3), dtype=np.uint8))
                output_img_np = overlay_text_on_frame(np.array(zeros_img), "blink")
                predicted_diameter = "blink"
            else:
                if left_eye is not None and eye_type == "left_eye":
                    if left_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        left_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                            left_pupil_model,
                            target_layer=target_layer,
                            fc_layer=left_pupil_model.resnet.fc,
                            input_shape=left_eye.shape,
                        )
                    output = left_pupil_model(left_eye)
                    predicted_diameter = output[0].item()
                    act_maps = left_pupil_cam_extractor(0, output)
                    activation_map = act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    if right_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        right_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                            right_pupil_model,
                            target_layer=target_layer,
                            fc_layer=right_pupil_model.resnet.fc,
                            input_shape=right_eye.shape,
                        )
                    output = right_pupil_model(right_eye)
                    predicted_diameter = output[0].item()
                    act_maps = right_pupil_cam_extractor(0, output)
                    activation_map = (
                        act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
                    )
                    input_image_pil = to_pil_image(right_eye.squeeze(0))
                # Create CAM overlay
                activation_map_pil = to_pil_image(activation_map, mode="F")
                result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
                input_img_np = np.array(input_image_pil)
                output_img_np = np.array(result)
            # Add frame and predicted diameter to lists
            input_frames[eye_type].append(input_img_np)
            output_frames[eye_type].append(output_img_np)
            predicted_diameters[eye_type].append(predicted_diameter)
            if output_path:
                height, width, _ = output_img_np.shape
                frame = np.zeros((height, width, 3), dtype=np.uint8)
                if not isinstance(predicted_diameter, str):
                    text = f"{predicted_diameter:.2f}"
                else:
                    text = predicted_diameter
                frame = overlay_text_on_frame(frame, text)
                pred_diameters_frames[eye_type].append(frame)
                combined_frame = np.vstack((input_img_np, output_img_np, frame))
                img_base64 = pil_image_to_base64(Image.fromarray(combined_frame))
                image_html = f'<div style="width: {str(50 * len(selected_eyes))}%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
                video_placeholders[eye_type].markdown(image_html, unsafe_allow_html=True)
                # video_placeholders[eye_type].image(combined_frame, use_column_width=True)
        st.session_state.current_frame = idx + 1
        txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
    if output_path:
        combine_and_show_frames(
            input_frames, output_frames, pred_diameters_frames, output_path, codec, video_placeholders
        )
    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios

# Function to display video with autoplay and loop
def display_video_with_autoplay(video_col, video_path, width):
    # NOTE: `video_path` is a base64-encoded video string, not a filesystem path.
    video_html = f"""
        <video width="{str(width)}%" height="auto" autoplay loop muted>
            <source src="data:video/mp4;base64,{video_path}" type="video/mp4">
        </video>
    """
    video_col.markdown(video_html, unsafe_allow_html=True)

def process_video(cols, video_frames, tv_model, pupil_selection, output_path, cam_method, blink_detection=False):
    resized_frames = []
    for i, frame in enumerate(video_frames):
        input_img = resize_frame(frame, max_width=640, max_height=480)
        resized_frames.append(input_img)
    file_format = output_path.split(".")[-1]
    codec, extension = get_codec_and_extension(file_format)
    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
        cols, resized_frames, tv_model, pupil_selection, cam_method, output_path, codec, blink_detection
    )
    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios

# Function to convert string values to float or None
def convert_diameter(value):
    try:
        return float(value)
    except (ValueError, TypeError):
        return None  # Return None if conversion fails

def combine_and_show_frames(input_frames, cam_frames, pred_diameters_frames, output_path, codec, video_cols):
    # Assuming all frames have the same keys (eye types)
    eye_types = input_frames.keys()
    for i, eye_type in enumerate(eye_types):
        in_frames = input_frames[eye_type]
        cam_out_frames = cam_frames[eye_type]
        pred_diameters_text_frames = pred_diameters_frames[eye_type]
        # Get frame properties (assuming all frames have the same dimensions)
        height, width, _ = in_frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*codec)
        fps = 10.0
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height * 3))  # Height is tripled for vertical stacking
        # Loop through each set of frames and stack them
        for j in range(len(in_frames)):
            input_frame = in_frames[j]
            cam_frame = cam_out_frames[j]
            pred_frame = pred_diameters_text_frames[j]
            # Convert frames to BGR if necessary
            input_frame_bgr = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
            cam_frame_bgr = cv2.cvtColor(cam_frame, cv2.COLOR_RGB2BGR)
            pred_frame_bgr = cv2.cvtColor(pred_frame, cv2.COLOR_RGB2BGR)
            # Stack frames vertically (input, CAM, prediction)
            combined_frame = np.vstack((input_frame_bgr, cam_frame_bgr, pred_frame_bgr))
            # Write the combined frame to the video
            out.write(combined_frame)
        # Release the video writer
        out.release()
        # Read the video and encode it in base64 for displaying
        with open(output_path, "rb") as video_file:
            video_bytes = video_file.read()
            video_base64 = base64.b64encode(video_bytes).decode("utf-8")
        # Display the combined video
        display_video_with_autoplay(video_cols[eye_type], video_base64, width=len(video_cols) * 50)
        # Clean up
        os.remove(output_path)

def set_input_image_on_ui(uploaded_file, cols):
    input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
    # NOTE: Images taken with a phone camera often carry an EXIF orientation tag that rotates
    # photos taken with the phone in a tilted position. PIL's exif_transpose() applies and
    # removes this tag, "uprighting" the image.
    input_img = ImageOps.exif_transpose(input_img)
    input_img = resize_frame(input_img, max_width=640, max_height=480)
    cols[0].image(input_img, use_column_width=True)
    st.session_state.total_frames = 1
    return input_img

def set_input_video_on_ui(uploaded_file, cols):
    tfile = tempfile.NamedTemporaryFile(delete=False)
    try:
        # Uploaded via the file uploader (object exposes a .read() method).
        tfile.write(uploaded_file.read())
    except Exception:
        # A sample-library selection passes raw bytes instead of an uploaded file object.
        tfile.write(uploaded_file)
    tfile.flush()  # ensure all bytes are on disk before OpenCV reads the file
    video_path = tfile.name
    video_frames = extract_frames(video_path)
    cols[0].video(video_path)
    st.session_state.total_frames = len(video_frames)
    return video_frames, video_path

def set_frames_processed_count_placeholder(cols):
    st.session_state.current_frame = 0
    st.session_state.frame_placeholder = cols[0].empty()
    txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
    st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)

def video_to_bytes(video_path):
    # Open the video file in binary mode and return the bytes
    with open(video_path, "rb") as video_file:
        return video_file.read()

def display_video_library(video_folder="./sample_videos"):
    # Get all video files from the folder
    video_files = [f for f in os.listdir(video_folder) if f.endswith(".webm")]
    # Store the selected video path
    selected_video_path = None
    # Calculate number of columns (adjust based on your layout preferences)
    num_columns = 3  # For a grid of 3 videos per row
    # Display videos in a grid layout with a 'Select' button for each video
    for i in range(0, len(video_files), num_columns):
        cols = st.columns(num_columns)
        for idx, video_file in enumerate(video_files[i : i + num_columns]):
            with cols[idx]:
                st.subheader(video_file.split(".")[0])  # Use the file name as the title
                video_path = os.path.join(video_folder, video_file)
                st.video(video_path)  # Show the video
                if st.button(f"Select {video_file.split('.')[0]}", key=video_file, type="primary"):
                    st.session_state.clear()
                    st.toast("Scroll down to see the input and predictions", icon="⏬")
                    selected_video_path = video_path  # Store the path of the selected video
    return selected_video_path

def set_page_info_and_sidebar_info():
    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
    st.title("👁️ PupilSense 👁️🕵️‍♂️")
    # st.markdown("Upload your own images or video **OR** select from our sample library below")
    st.markdown(
        "<p style='font-size: 30px;'>"
        "Upload your own image 🖼️ or video 🎞️ <strong>OR</strong> select from our sample videos 📚"
        "</p>",
        unsafe_allow_html=True,
    )
    # video_path = display_video_library()
    show_demo_videos = st.sidebar.checkbox("Show Sample Videos", value=False)
    if show_demo_videos:
        video_path = display_video_library()
    else:
        video_path = None
    st.markdown("<hr id='target_element' style='border: 1px solid #6d6d6d; margin: 20px 0;'>", unsafe_allow_html=True)
    cols = st.columns((1, 1))
    cols[0].header("Input")
    cols[-1].header("Prediction")
    st.markdown("<hr style='border: 1px solid #6d6d6d; margin: 20px 0;'>", unsafe_allow_html=True)
    LABEL_MAP = ["left_pupil", "right_pupil"]
    TV_MODELS = ["ResNet18", "ResNet50"]
    if "uploader_key" not in st.session_state:
        st.session_state["uploader_key"] = 1
    st.sidebar.title("Upload Face 👨‍🦱 or Eye 👁️")
    uploaded_file = st.sidebar.file_uploader(
        "Upload Image or Video",
        type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"],
        key=st.session_state["uploader_key"],
    )
    if uploaded_file is not None:
        st.session_state["uploaded_file"] = uploaded_file
    st.sidebar.title("Setup")
    pupil_selection = st.sidebar.selectbox(
        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
    )
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
    blink_detection = st.sidebar.checkbox("Detect Blinks", value=True)
    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)
    if "uploaded_file" not in st.session_state:
        st.session_state["uploaded_file"] = None
    if "og_video_path" not in st.session_state:
        st.session_state["og_video_path"] = None
    if uploaded_file is None and video_path is not None:
        video_bytes = video_to_bytes(video_path)
        uploaded_file = video_bytes
        st.session_state["uploaded_file"] = uploaded_file
        st.session_state["og_video_path"] = video_path
        st.session_state["uploader_key"] = 0
    return (
        cols,
        st.session_state["og_video_path"],
        st.session_state["uploaded_file"],
        pupil_selection,
        tv_model,
        blink_detection,
    )

def pil_image_to_base64(img):
    """Convert a PIL Image to a base64 encoded string."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str

def process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection):
    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
        cols,
        [input_img],
        tv_model,
        pupil_selection,
        cam_method=CAM_METHODS[-1],
        blink_detection=blink_detection,
    )
    # for ff in face_frames:
    #     if ff["has_face"]:
    #         cols[1].image(face_frames[0]["img"], use_column_width=True)
    input_frames_keys = input_frames.keys()
    video_cols = cols[1].columns(len(input_frames_keys))
    for i, eye_type in enumerate(input_frames_keys):
        # Check the pupil_selection and set the width accordingly
        if pupil_selection == "both":
            video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
        else:
            img_base64 = pil_image_to_base64(Image.fromarray(input_frames[eye_type][-1]))
            image_html = f'<div style="width: 50%; margin-bottom: 1.2%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)
    output_frames_keys = output_frames.keys()
    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
    for i, eye_type in enumerate(output_frames_keys):
        height, width, c = output_frames[eye_type][0].shape
        frame = np.zeros((height, width, c), dtype=np.uint8)
        predicted_diameter = predicted_diameters[eye_type][0]
        # The prediction is the string "blink" when a blink was detected; only format numbers.
        if not isinstance(predicted_diameter, str):
            text = f"{predicted_diameter:.2f}"
        else:
            text = predicted_diameter
        frame = overlay_text_on_frame(frame, text)
        if pupil_selection == "both":
            video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
            video_cols[i].image(frame, use_column_width=True)
        else:
            img_base64 = pil_image_to_base64(Image.fromarray(output_frames[eye_type][-1]))
            image_html = f'<div style="width: 50%; margin-top: 1.2%; margin-bottom: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)
            img_base64 = pil_image_to_base64(Image.fromarray(frame))
            image_html = f'<div style="width: 50%; margin-top: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)
    return None

def plot_ears(eyes_ratios, eyes_df):
    eyes_df["EAR"] = eyes_ratios
    df = pd.DataFrame(eyes_ratios, columns=["EAR"])
    df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1
    # Create an Altair chart for eyes_ratios
    line_chart = (
        alt.Chart(df)
        .mark_line(color=colors[-1])  # Set color of the line
        .encode(
            x=alt.X("Frame:Q", title="Frame Number"),
            y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
            tooltip=["Frame", "EAR"],
        )
        # .properties(title="Eyes Aspect Ratios (EARs)")
        # .configure_axis(grid=True)
    )
    points_chart = line_chart.mark_point(color=colors[-1], filled=True)
    # Create horizontal rules at the blink thresholds (y=0.22 and y=0.25)
    line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
    line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")
    # Add text annotations for the lines
    text1 = (
        alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
        .mark_text(align="left", dx=100, dy=9, color="red", size=16)
        .encode(y="y:Q", text="label:N")
    )
    text2 = (
        alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
        .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
        .encode(y="y:Q", text="label:N")
    )
    # Add gray-area text for the region between the red and green lines
    gray_area_text = (
        alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
        .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
        .encode(y="y:Q", text="label:N")
    )
    # Combine all elements: line chart, points, rules, and text annotations
    final_chart = (
        line_chart.properties(title="Eyes Aspect Ratios (EARs)")
        + points_chart
        + line1
        + line2
        + text1
        + text2
        + gray_area_text
    ).interactive()
    # Configure axis properties at the chart level
    final_chart = final_chart.configure_axis(grid=True)
    # Display the Altair chart
    # st.subheader("Eyes Aspect Ratios (EARs)")
    st.altair_chart(final_chart, use_container_width=True)
    return eyes_df

def plot_individual_charts(predicted_diameters, cols):
    # Iterate through categories and assign charts to columns
    for i, (category, values) in enumerate(predicted_diameters.items()):
        with cols[i]:  # Directly use the column index
            # st.subheader(category)  # Add a subheader for the category
            if "left" in category:
                selected_color = colors[0]
            elif "right" in category:
                selected_color = colors[1]
            else:
                selected_color = colors[i]
            # Convert values to numeric, replacing non-numeric values with None
            values = [convert_diameter(value) for value in values]
            if "left" in category:
                category_name = "Left Pupil Diameter"
            else:
                category_name = "Right Pupil Diameter"
            # Create a DataFrame from the values for Altair
            df = pd.DataFrame(
                {
                    "Frame": range(1, len(values) + 1),
                    category_name: values,
                }
            )
            # Get the min and max values for y-axis limits, ignoring None
            min_value = min(filter(lambda x: x is not None, values), default=None)
            max_value = max(filter(lambda x: x is not None, values), default=None)
            # Create an Altair chart with y-axis limits
            line_chart = (
                alt.Chart(df)
                .mark_line(color=selected_color)
                .encode(
                    x=alt.X("Frame:Q", title="Frame Number"),
                    y=alt.Y(
                        f"{category_name}:Q",
                        title="Diameter",
                        scale=alt.Scale(domain=[min_value, max_value]),
                    ),
                    tooltip=[
                        "Frame",
                        alt.Tooltip(f"{category_name}:Q", title="Diameter"),
                    ],
                )
                # .properties(title=f"{category} - Predicted Diameters")
                # .configure_axis(grid=True)
            )
            points_chart = line_chart.mark_point(color=selected_color, filled=True)
            final_chart = (
                line_chart.properties(
                    title=f"{'Left Pupil' if 'left' in category else 'Right Pupil'} - Predicted Diameters"
                )
                + points_chart
            ).interactive()
            final_chart = final_chart.configure_axis(grid=True)
            # Display the Altair chart
            st.altair_chart(final_chart, use_container_width=True)
    return df

def plot_combined_charts(predicted_diameters):
    all_min_values = []
    all_max_values = []
    # Create an empty DataFrame to store combined data for plotting
    combined_df = pd.DataFrame()
    # Iterate through categories and collect data
    for category, values in predicted_diameters.items():
        # Convert values to numeric, replacing non-numeric values with None
        values = [convert_diameter(value) for value in values]
        # Get the min and max values for y-axis limits, ignoring None
        min_value = min(filter(lambda x: x is not None, values), default=None)
        max_value = max(filter(lambda x: x is not None, values), default=None)
        all_min_values.append(min_value)
        all_max_values.append(max_value)
        category = "left_pupil" if "left" in category else "right_pupil"
        # Create a DataFrame from the values
        df = pd.DataFrame(
            {
                "Diameter": values,
                "Frame": range(1, len(values) + 1),  # Create a frame column starting from 1
                "Category": category,  # Add a column to specify the category
            }
        )
        # Append to the combined DataFrame
        combined_df = pd.concat([combined_df, df], ignore_index=True)
    combined_chart = (
        alt.Chart(combined_df)
        .mark_line()
        .encode(
            x=alt.X("Frame:Q", title="Frame Number"),
            y=alt.Y(
                "Diameter:Q",
                title="Diameter",
                scale=alt.Scale(domain=[min(all_min_values), max(all_max_values)]),
            ),
            color=alt.Color("Category:N", scale=alt.Scale(range=colors), title="Pupil Type"),
            tooltip=["Frame", "Diameter:Q", "Category:N"],
        )
    )
    points_chart = combined_chart.mark_point(filled=True)
    final_chart = (combined_chart.properties(title="Predicted Diameters") + points_chart).interactive()
    final_chart = final_chart.configure_axis(grid=True)
    # Display the combined chart
    st.altair_chart(final_chart, use_container_width=True)
    # --------------------------------------------
    # Convert to a DataFrame
    left_pupil_values = [convert_diameter(value) for value in predicted_diameters["left_eye"]]
    right_pupil_values = [convert_diameter(value) for value in predicted_diameters["right_eye"]]
    df = pd.DataFrame(
        {
            "Frame": range(1, len(left_pupil_values) + 1),
            "Left Pupil Diameter": left_pupil_values,
            "Right Pupil Diameter": right_pupil_values,
        }
    )
    # Calculate the difference between left and right pupil diameters
    df["Difference Value"] = df["Left Pupil Diameter"] - df["Right Pupil Diameter"]
    # Determine the status of the difference
    df["Difference Status"] = df.apply(
        lambda row: "L>R" if row["Left Pupil Diameter"] > row["Right Pupil Diameter"] else "L<R",
        axis=1,
    )
    return df

def process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path):
    output_video_path = f"{root_path}/tmp.webm"
    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
        cols,
        video_frames,
        tv_model,
        pupil_selection,
        output_video_path,
        cam_method=CAM_METHODS[-1],
        blink_detection=blink_detection,
    )
    os.remove(video_path)
    num_columns = len(predicted_diameters)
    cols = st.columns(num_columns)
    if num_columns == 2:
        df = plot_combined_charts(predicted_diameters)
    else:
        df = plot_individual_charts(predicted_diameters, cols)
    if eyes_ratios is not None and len(eyes_ratios) > 0:
        df = plot_ears(eyes_ratios, df)
    st.dataframe(df, hide_index=True, use_container_width=True)
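

# The block below is a minimal entry-point sketch and is NOT part of the original file: it only
# illustrates, under the assumption that the actual Streamlit app wires these helpers together
# in a separate script, how the sidebar setup, input handling, and processing functions above
# are intended to compose. File-type dispatch via is_image()/is_video() simply mirrors the
# extensions accepted by the uploader; the sample-library branch assumes `uploaded_file` holds
# raw bytes, as set_page_info_and_sidebar_info() does via video_to_bytes().
if __name__ == "__main__":
    cols, og_video_path, uploaded_file, pupil_selection, tv_model, blink_detection = set_page_info_and_sidebar_info()
    if uploaded_file is not None:
        if og_video_path is not None:
            # A sample-library video was selected; `uploaded_file` is raw bytes here.
            video_frames, video_path = set_input_video_on_ui(uploaded_file, cols)
            set_frames_processed_count_placeholder(cols)
            process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path)
        else:
            file_extension = uploaded_file.name.split(".")[-1]
            if is_image(file_extension):
                input_img = set_input_image_on_ui(uploaded_file, cols)
                set_frames_processed_count_placeholder(cols)
                process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection)
            elif is_video(file_extension):
                video_frames, video_path = set_input_video_on_ui(uploaded_file, cols)
                set_frames_processed_count_placeholder(cols)
                process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path)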