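"""Streamlit helpers for video-based pupil diameter estimation.

This module covers frame extraction, eye-region preprocessing, ResNet-based
pupil diameter prediction with torchcam class activation map (CAM) overlays,
and assembling the per-eye results into videos for display in the app.
"""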
import base64
import os
import os.path as osp
import sys
import tempfile
from io import BytesIO

import cv2
import numpy as np
import streamlit as st
import torch
from matplotlib import pyplot as plt
from PIL import Image
from torchcam import methods as torchcam_methods
from torchcam.methods import CAM
from torchcam.utils import overlay_mask
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model


def load_model(model_configs, device="cpu"):
    """Loads the pre-trained model."""
    model_path = os.path.join(root_path, model_configs["model_path"])
    model_dict = torch.load(model_path, map_location=device)
    model = get_model(model_configs=model_configs)
    model.load_state_dict(model_dict)
    model = model.to(device).eval()
    return model


def extract_frames(video_path):
    """Extracts frames from a video file."""
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    success, image = vidcap.read()
    while success:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        frames.append(image_rgb)
        success, image = vidcap.read()
    vidcap.release()
    return frames


def resize_frame(image, max_width=640, max_height=480):
    """Resizes a frame to fit within the given bounds, mirroring the image resizing logic."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    original_size = image.size

    # Resize the frame similarly to the image resizing logic
    if original_size[0] == original_size[1] and original_size[0] >= 256:
        max_size = (256, 256)
    else:
        max_size = list(original_size)
        if original_size[0] >= max_width:
            max_size[0] = max_width
        elif original_size[0] < 64:
            max_size[0] = 64
        if original_size[1] >= max_height:
            max_size[1] = max_height
        elif original_size[1] < 32:
            max_size[1] = 32

    image.thumbnail(max_size)
    # image = image.resize(max_size)
    return image


def is_image(file_extension):
    """Checks if the file is an image."""
    return file_extension.lower() in ["png", "jpeg", "jpg"]


def is_video(file_extension):
    """Checks if the file is a video."""
    return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]


def get_codec_and_extension(file_format):
    """Return codec and file extension based on the format."""
    if file_format == "mp4":
        return "H264", ".mp4"
    elif file_format == "avi":
        return "MJPG", ".avi"
    elif file_format == "webm":
        return "VP80", ".webm"
    else:
        return "MJPG", ".avi"


def display_results(input_image, cam_frame, pupil_diameter, cols):
    """Displays the input image and overlayed CAM result."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(input_image)
    axs[0].axis("off")
    axs[0].set_title("Input Image")
    axs[1].imshow(cam_frame)
    axs[1].axis("off")
    axs[1].set_title("Overlayed CAM")
    cols[-1].pyplot(fig)
    cols[-1].text(f"Pupil Diameter: {pupil_diameter:.2f} mm")


def preprocess_image(input_img, max_size=(256, 256)):
    """Resizes and preprocesses an image."""
    input_img.thumbnail(max_size)
    preprocess_steps = [
        transforms.ToTensor(),
        transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    ]
    return transforms.Compose(preprocess_steps)(input_img).unsqueeze(0)


def overlay_text_on_frame(frame, text, position=(16, 20)):
    """Writes text on the image frame using OpenCV (drawn in place; the same array is returned)."""
    return cv2.putText(frame, text, position, cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, cv2.LINE_AA)


def get_configs(blink_detection=False):
    # Upscaling is disabled by default ("-"), so no super-resolution configs are created.
    upscale = "-"
    upscale_method_or_model = "-"
    if upscale == "-":
        sr_configs = None
    else:
        sr_configs = {
            "method": upscale_method_or_model,
            "params": {"upscale": upscale},
        }
    config_file = {
        "sr_configs": sr_configs,
        "feature_extraction_configs": {
            "blink_detection": blink_detection,
            "upscale": upscale,
            "extraction_library": "mediapipe",
        },
    }
    return config_file


def setup(cols, pupil_selection, tv_model, output_path):
    left_pupil_model = None
    left_pupil_cam_extractor = None
    right_pupil_model = None
    right_pupil_cam_extractor = None
    output_frames = {}
    input_frames = {}
    predicted_diameters = {}
    pred_diameters_frames = {}

    if pupil_selection == "both":
        selected_eyes = ["left_eye", "right_eye"]
    elif pupil_selection == "left_pupil":
        selected_eyes = ["left_eye"]
    elif pupil_selection == "right_pupil":
        selected_eyes = ["right_eye"]
    else:
        raise ValueError(f"Unexpected pupil selection: {pupil_selection}")

    for eye_type in selected_eyes:
        model_configs = {
            "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
            "registered_model_name": tv_model,
            "num_classes": 1,
        }
        if eye_type == "left_eye":
            left_pupil_model = load_model(model_configs)
            left_pupil_cam_extractor = None
        else:
            right_pupil_model = load_model(model_configs)
            right_pupil_cam_extractor = None
        output_frames[eye_type] = []
        input_frames[eye_type] = []
        predicted_diameters[eye_type] = []
        pred_diameters_frames[eye_type] = []

    video_placeholders = {}
    if output_path:
        video_cols = cols[1].columns(len(input_frames.keys()))
        for i, eye_type in enumerate(list(input_frames.keys())):
            video_placeholders[eye_type] = video_cols[i].empty()

    return (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        pred_diameters_frames,
        video_placeholders,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    )


def process_frames(
    cols, input_imgs, tv_model, pupil_selection, cam_method, output_path=None, codec=None, blink_detection=False
):
    config_file = get_configs(blink_detection)
    face_frames = []

    (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        pred_diameters_frames,
        video_placeholders,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    ) = setup(cols, pupil_selection, tv_model, output_path)

    ds_creation = EyeDentityDatasetCreation(
        feature_extraction_configs=config_file["feature_extraction_configs"],
        sr_configs=config_file["sr_configs"],
    )

    preprocess_steps = [
        transforms.ToTensor(),
        transforms.Resize(
            [32, 64],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
    ]
    preprocess_function = transforms.Compose(preprocess_steps)

    eyes_ratios = []
    for idx, input_img in enumerate(input_imgs):
        img = np.array(input_img)
        ds_results = ds_creation(img)
        left_eye = None
        right_eye = None
        blinked = False
        eyes_ratio = None

        if ds_results is not None and "face" in ds_results:
            face_img = to_pil_image(ds_results["face"])
            has_face = True
        else:
            face_img = to_pil_image(np.zeros((256, 256, 3), dtype=np.uint8))
            has_face = False
        face_frames.append({"has_face": has_face, "img": face_img})

        if ds_results is not None and "eyes" in ds_results.keys():
            blinked = ds_results["eyes"]["blinked"]
            eyes_ratio = ds_results["eyes"]["eyes_ratio"]
            if eyes_ratio is not None:
                eyes_ratios.append(eyes_ratio)
            if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
                left_eye = ds_results["eyes"]["left_eye"]
                left_eye = to_pil_image(left_eye).convert("RGB")
                left_eye = preprocess_function(left_eye)
                left_eye = left_eye.unsqueeze(0)
            if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
                right_eye = ds_results["eyes"]["right_eye"]
                right_eye = to_pil_image(right_eye).convert("RGB")
                right_eye = preprocess_function(right_eye)
                right_eye = right_eye.unsqueeze(0)
        else:
            # No eye crops detected: fall back to the preprocessed full input frame.
            input_img = preprocess_function(input_img)
            input_img = input_img.unsqueeze(0)
            if pupil_selection == "left_pupil":
                left_eye = input_img
            elif pupil_selection == "right_pupil":
                right_eye = input_img
            else:
                left_eye = input_img
                right_eye = input_img
        for eye_type in selected_eyes:
            if blinked:
                if left_eye is not None and eye_type == "left_eye":
                    _, height, width = left_eye.squeeze(0).shape
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    _, height, width = right_eye.squeeze(0).shape
                    input_image_pil = to_pil_image(right_eye.squeeze(0))
                input_img_np = np.array(input_image_pil)
                zeros_img = to_pil_image(np.zeros((height, width, 3), dtype=np.uint8))
                output_img_np = overlay_text_on_frame(np.array(zeros_img), "blink")
                predicted_diameter = "blink"
            else:
                if left_eye is not None and eye_type == "left_eye":
                    if left_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        # Look up the torchcam extractor class by name and attach it to the model.
                        left_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                            left_pupil_model,
                            target_layer=target_layer,
                            fc_layer=left_pupil_model.resnet.fc,
                            input_shape=left_eye.shape,
                        )
                    output = left_pupil_model(left_eye)
                    predicted_diameter = output[0].item()
                    act_maps = left_pupil_cam_extractor(0, output)
                    activation_map = (
                        act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
                    )
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    if right_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        right_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                            right_pupil_model,
                            target_layer=target_layer,
                            fc_layer=right_pupil_model.resnet.fc,
                            input_shape=right_eye.shape,
                        )
                    output = right_pupil_model(right_eye)
                    predicted_diameter = output[0].item()
                    act_maps = right_pupil_cam_extractor(0, output)
                    activation_map = (
                        act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
                    )
                    input_image_pil = to_pil_image(right_eye.squeeze(0))

                # Create CAM overlay
                activation_map_pil = to_pil_image(activation_map, mode="F")
                result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
                input_img_np = np.array(input_image_pil)
                output_img_np = np.array(result)

            # Add frame and predicted diameter to lists
            input_frames[eye_type].append(input_img_np)
            output_frames[eye_type].append(output_img_np)
            predicted_diameters[eye_type].append(predicted_diameter)

            if output_path:
                height, width, _ = output_img_np.shape
                frame = np.zeros((height, width, 3), dtype=np.uint8)
                if not isinstance(predicted_diameter, str):
                    text = f"{predicted_diameter:.2f}"
                else:
                    text = predicted_diameter
                frame = overlay_text_on_frame(frame, text)
                pred_diameters_frames[eye_type].append(frame)

                combined_frame = np.vstack((input_img_np, output_img_np, frame))
                video_placeholders[eye_type].image(combined_frame, use_column_width=True)

        st.session_state.current_frame = idx + 1
        txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
    if output_path:
        combine_and_show_frames(
            input_frames, output_frames, pred_diameters_frames, output_path, codec, video_placeholders
        )
    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios


def display_video_with_autoplay(video_col, video_base64):
    """Displays a base64-encoded MP4 video with autoplay and loop in the given Streamlit column."""
    video_html = f"""
        <video width="100%" height="auto" autoplay loop muted>
            <source src="data:video/mp4;base64,{video_base64}" type="video/mp4">
        </video>
    """
    video_col.markdown(video_html, unsafe_allow_html=True)


def process_video(cols, video_frames, tv_model, pupil_selection, output_path, cam_method, blink_detection=False):
    resized_frames = []
    for frame in video_frames:
        input_img = resize_frame(frame, max_width=640, max_height=480)
        resized_frames.append(input_img)

    file_format = output_path.split(".")[-1]
    codec, extension = get_codec_and_extension(file_format)

    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
        cols, resized_frames, tv_model, pupil_selection, cam_method, output_path, codec, blink_detection
    )
    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios


# Function to convert string values to float or None
def convert_diameter(value):
    try:
        return float(value)
    except (ValueError, TypeError):
        return None  # Return None if conversion fails (e.g., for "blink" entries)


def combine_and_show_frames(input_frames, cam_frames, pred_diameters_frames, output_path, codec, video_cols):
    # Assuming all frames have the same keys (eye types)
    eye_types = input_frames.keys()
    for eye_type in eye_types:
        in_frames = input_frames[eye_type]
        cam_out_frames = cam_frames[eye_type]
        pred_diameters_text_frames = pred_diameters_frames[eye_type]

        # Get frame properties (assuming all frames have the same dimensions)
        height, width, _ = in_frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*codec)
        fps = 10.0
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height * 3))  # Height is tripled for vertical stacking

        # Loop through each set of frames and stack them
        for j in range(len(in_frames)):
            input_frame = in_frames[j]
            cam_frame = cam_out_frames[j]
            pred_frame = pred_diameters_text_frames[j]

            # Convert frames from RGB to BGR for OpenCV
            input_frame_bgr = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
            cam_frame_bgr = cv2.cvtColor(cam_frame, cv2.COLOR_RGB2BGR)
            pred_frame_bgr = cv2.cvtColor(pred_frame, cv2.COLOR_RGB2BGR)

            # Stack frames vertically (input, CAM overlay, predicted diameter)
            combined_frame = np.vstack((input_frame_bgr, cam_frame_bgr, pred_frame_bgr))

            # Write the combined frame to the video
            out.write(combined_frame)

        # Release the video writer
        out.release()

        # Read the video and encode it in base64 for displaying
        with open(output_path, "rb") as video_file:
            video_bytes = video_file.read()
            video_base64 = base64.b64encode(video_bytes).decode("utf-8")

        # Display the combined video
        display_video_with_autoplay(video_cols[eye_type], video_base64)

        # Clean up
        os.remove(output_path)
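

# Illustrative usage sketch (an assumption, not part of the app wiring above): it assumes a
# Streamlit page that creates two columns, a video already saved to disk under a hypothetical
# name "uploaded.mp4", and the session-state fields process_frames() reads.
#
#     cols = st.columns(2)
#     frames = extract_frames("uploaded.mp4")
#     st.session_state.total_frames = len(frames)
#     st.session_state.current_frame = 0
#     st.session_state.frame_placeholder = cols[0].empty()
#     input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
#         cols, frames, tv_model="ResNet18", pupil_selection="both",
#         output_path="combined_output.webm", cam_method="CAM", blink_detection=True,
#     )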