|
|
|
|
|
|
|
""" |
|
utils.py |
|
|
|
This module provides utility functions for various tasks such as setting random seeds, |
|
importing modules from files, managing checkpoint files, and saving video files from |
|
sequences of PIL images. |
|
|
|
Functions: |
|
seed_everything(seed) |
|
import_filename(filename) |
|
delete_additional_ckpt(base_path, num_keep) |
|
save_videos_from_pil(pil_images, path, fps=8) |
|
|
|
Dependencies: |
|
importlib |
|
os |
|
os.path as osp |
|
random |
|
    shutil
    subprocess
    sys
|
pathlib.Path |
|
av |
|
cv2 |
|
mediapipe as mp |
|
numpy as np |
|
torch |
|
torchvision |
|
einops.rearrange |
|
moviepy.editor.AudioFileClip, VideoClip |
|
PIL.Image |
|
|
|
Examples: |
|
seed_everything(42) |
|
imported_module = import_filename('path/to/your/module.py') |
|
delete_additional_ckpt('path/to/checkpoints', 1) |
|
save_videos_from_pil(pil_images, 'output/video.mp4', fps=12) |
|
|
|
The functions in this module ensure reproducibility of experiments by seeding random number |
|
generators, allow dynamic importing of modules, manage checkpoint files by deleting extra ones, |
|
and provide a way to save sequences of images as video files. |
|
|
|
Function Details: |
|
seed_everything(seed) |
|
Seeds all random number generators to ensure reproducibility. |
|
|
|
import_filename(filename) |
|
Imports a module from a given file location. |
|
|
|
delete_additional_ckpt(base_path, num_keep) |
|
Deletes additional checkpoint files in the given directory. |
|
|
|
save_videos_from_pil(pil_images, path, fps=8) |
|
        Saves a sequence of images as an .mp4 (via PyAV) or .gif (via Pillow) video file.
|
|
|
Attributes: |
|
_ (str): Placeholder for static type checking |
|
""" |
|
|
|
import importlib |
|
import os |
|
import os.path as osp |
|
import random |
|
import shutil |
|
import subprocess |
|
import sys |
|
from pathlib import Path |
|
|
|
import av |
|
import cv2 |
|
import mediapipe as mp |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
from einops import rearrange |
|
from moviepy.editor import AudioFileClip, VideoClip |
|
from PIL import Image |
|
|
|
|
|
def seed_everything(seed):
    """
    Seed every random number generator used by this project.

    Args:
        seed (int): The seed applied to the ``random`` module, numpy,
            and torch (CPU and all CUDA devices).
    """
    random.seed(seed)
    # numpy only accepts seeds in [0, 2**32), so wrap larger values.
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
def import_filename(filename):
    """
    Import a Python module from an arbitrary file location.

    Args:
        filename (str): Path to the ``.py`` file containing the module.

    Returns:
        module: The imported module object.

    Raises:
        ImportError: If no import spec can be created for the file
            (e.g. the path does not point to an importable module).

    Example:
        >>> imported_module = import_filename('path/to/your/module.py')
    """
    spec = importlib.util.spec_from_file_location("mymodule", filename)
    # spec_from_file_location returns None for unrecognized paths; without
    # this guard the next line would raise a confusing AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import module from file: {filename}")
    module = importlib.util.module_from_spec(spec)
    # Register before execution so code run during exec_module can see
    # the partially-initialized module in sys.modules.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
|
|
|
|
|
def delete_additional_ckpt(base_path, num_keep):
    """
    Remove the oldest ``checkpoint-*`` directories, keeping only the newest.

    Args:
        base_path (str): Directory that holds the checkpoint folders.
        num_keep (int): How many of the most recent checkpoints to keep.

    Returns:
        None

    Raises:
        FileNotFoundError: If ``base_path`` does not exist.

    Example:
        >>> delete_additional_ckpt('path/to/checkpoints', 1)
        # Deletes all but the most recent checkpoint directory.
    """
    ckpt_names = [name for name in os.listdir(base_path)
                  if name.startswith("checkpoint-")]
    surplus = len(ckpt_names) - num_keep
    if surplus <= 0:
        return

    # Directories are named "checkpoint-<step>"; the numeric suffix orders them.
    ckpt_names.sort(key=lambda name: int(name.split("-")[-1]))
    for name in ckpt_names[:surplus]:
        target = osp.join(base_path, name)
        if osp.exists(target):
            shutil.rmtree(target)
|
|
|
|
|
def save_videos_from_pil(pil_images, path, fps=8):
    """
    Save a sequence of PIL images as an .mp4 (via PyAV/libx264) or .gif
    (via Pillow) animation.

    Args:
        pil_images (List[PIL.Image]): Frames of the video; all frames are
            assumed to share the size of the first one.
        path (str): Output file path; the suffix selects the format.
        fps (int, optional): Frame rate of the output. Defaults to 8.

    Returns:
        None

    Raises:
        ValueError: If the path suffix is neither ".mp4" nor ".gif".
    """
    save_fmt = Path(path).suffix
    os.makedirs(os.path.dirname(path), exist_ok=True)
    width, height = pil_images[0].size

    if save_fmt == ".mp4":
        codec = "libx264"
        container = av.open(path, "w")
        stream = container.add_stream(codec, rate=fps)
        stream.width = width
        stream.height = height
        for pil_image in pil_images:
            av_frame = av.VideoFrame.from_image(pil_image)
            container.mux(stream.encode(av_frame))
        # encode() with no frame flushes packets still buffered in the encoder.
        container.mux(stream.encode())
        container.close()
    elif save_fmt == ".gif":
        pil_images[0].save(
            fp=path,
            format="GIF",
            append_images=pil_images[1:],
            save_all=True,
            duration=(1 / fps * 1000),  # per-frame duration in milliseconds
            loop=0,  # 0 = loop forever
        )
    else:
        raise ValueError("Unsupported file type. Use .mp4 or .gif.")
|
|
|
|
|
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    """
    Arrange a batch of videos into a per-frame grid and save the result.

    Args:
        videos (torch.Tensor): Tensor shaped (batch, channels, time, height, width).
        path (str): Destination file; .mp4 and .gif are supported.
        rescale (bool, optional): If True, map values from [-1, 1] to [0, 1]
            before quantizing. Defaults to False.
        n_rows (int, optional): Grid row size passed to make_grid. Defaults to 6.
        fps (int, optional): Frame rate of the saved video. Defaults to 8.

    Returns:
        None
    """
    frames = []
    # Iterate over time so each step yields one grid image across the batch.
    for frame_batch in rearrange(videos, "b c t h w -> t b c h w"):
        grid = torchvision.utils.make_grid(frame_batch, nrow=n_rows)
        # CHW -> HWC for image export.
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        frames.append(Image.fromarray((grid * 255).numpy().astype(np.uint8)))

    os.makedirs(os.path.dirname(path), exist_ok=True)
    save_videos_from_pil(frames, path, fps)
|
|
|
|
|
def read_frames(video_path):
    """
    Decode all frames of a video file into PIL images.

    Args:
        video_path (str): The path to the video file.

    Returns:
        List[PIL.Image]: The decoded RGB frames, in decode order.

    Raises:
        FileNotFoundError: If the video file is not found.
        RuntimeError: If there is an error in reading the video stream.
    """
    container = av.open(video_path)
    try:
        video_stream = next(s for s in container.streams if s.type == "video")
        frames = []
        for packet in container.demux(video_stream):
            for frame in packet.decode():
                image = Image.frombytes(
                    "RGB",
                    (frame.width, frame.height),
                    frame.to_rgb().to_ndarray(),
                )
                frames.append(image)
    finally:
        # The original leaked the demuxer; always release the container,
        # even when decoding raises.
        container.close()
    return frames
|
|
|
|
|
def get_fps(video_path):
    """
    Return the average frame rate of the first video stream in a file.

    Args:
        video_path (str): The path to the video file.

    Returns:
        The stream's ``average_rate`` as reported by PyAV
        (a fractions.Fraction, or None when unknown).
    """
    container = av.open(video_path)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        return stream.average_rate
    finally:
        container.close()
|
|
|
|
|
def tensor_to_video(tensor, output_video_file, audio_source, fps=25):
    """
    Render a [c, f, h, w] tensor as a video with an audio track.

    Args:
        tensor (Tensor): Video data shaped [c, f, h, w]; values are scaled
            by 255 and clipped, so they are assumed to lie in [0, 1].
        output_video_file (str): The file path where the output video is saved.
        audio_source (str): Path to the audio file whose track is added.
        fps (int): Frame rate of the output video. Default is 25 fps.
    """
    # [c, f, h, w] -> [f, h, w, c] on the CPU, then quantize to 8-bit.
    frames = tensor.permute(1, 2, 3, 0).cpu().numpy()
    frames = np.clip(frames * 255, 0, 255).astype(np.uint8)
    duration = frames.shape[0] / fps

    def make_frame(t):
        # Clamp so rounding at the clip boundary never indexes past the end.
        return frames[min(int(t * fps), frames.shape[0] - 1)]

    clip = VideoClip(make_frame, duration=duration)
    clip = clip.set_audio(AudioFileClip(audio_source).subclip(0, duration))
    clip.write_videofile(output_video_file, fps=fps)
|
|
|
|
|
# Landmark indices into the face-landmarker output (see get_landmark):
# the outer face silhouette (jawline and forehead boundary), used by
# get_face_mask to bound the face region.
silhouette_ids = [
    10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
    397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
    172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109
]
# Landmark indices around the lips, used by get_lip_mask to bound the mouth.
lip_ids = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291,
           146, 91, 181, 84, 17, 314, 405, 321, 375]
|
|
|
|
|
def compute_face_landmarks(detection_result, h, w):
    """
    Convert normalized face landmarks to pixel coordinates.

    Args:
        detection_result: Face landmarker result whose ``face_landmarks``
            attribute holds one landmark list per detected face.
        h (int): Frame height in pixels.
        w (int): Frame width in pixels.

    Returns:
        list: ``[x, y]`` pixel coordinates per landmark, or an empty list
        when the frame does not contain exactly one face.
    """
    faces = detection_result.face_landmarks
    if len(faces) != 1:
        # Masks are only meaningful for a single face; report and skip.
        print("#face is invalid:", len(faces))
        return []
    return [[point.x * w, point.y * h] for point in faces[0]]
|
|
|
|
|
def get_landmark(file):
    """
    Detect facial landmarks in an image file with the MediaPipe tasks API.

    Args:
        file (str): Path to the image file to process.

    Returns:
        Tuple[np.ndarray, int, int]: ``(landmarks, height, width)`` where
        ``landmarks`` is an array of ``[x, y]`` pixel coordinates (empty
        when not exactly one face is found — see compute_face_landmarks),
        and ``height``/``width`` are the image dimensions in pixels.
    """
    # Local path to the bundled FaceLandmarker task model.
    model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
    BaseOptions = mp.tasks.BaseOptions
    FaceLandmarker = mp.tasks.vision.FaceLandmarker
    FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
    VisionRunningMode = mp.tasks.vision.RunningMode

    options = FaceLandmarkerOptions(
        base_options=BaseOptions(model_asset_path=model_path),
        running_mode=VisionRunningMode.IMAGE,  # single image, not video/stream
    )

    # The context manager releases the landmarker's native resources.
    with FaceLandmarker.create_from_options(options) as landmarker:
        image = mp.Image.create_from_file(str(file))
        height, width = image.height, image.width
        face_landmarker_result = landmarker.detect(image)
        face_landmark = compute_face_landmarks(
            face_landmarker_result, height, width)

    return np.array(face_landmark), height, width
|
|
|
|
|
def get_lip_mask(landmarks, height, width, out_path):
    """
    Write a binary mask covering an expanded lip bounding box.

    Parameters:
        landmarks (numpy.ndarray): Full set of facial landmarks ([x, y] rows).
        height (int): Mask image height in pixels.
        width (int): Mask image width in pixels.
        out_path (pathlib.Path): Destination of the mask image.
    """
    lip_points = landmarks[lip_ids]
    lo = np.round(lip_points.min(axis=0))
    hi = np.round(lip_points.max(axis=0))
    # Grow the tight lip box 2x so the mask covers the whole mouth region,
    # clamped to the image bounds.
    min_x, max_x, min_y, max_y = expand_region(
        [lo[0], hi[0], lo[1], hi[1]], width, height, 2.0)
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
    cv2.imwrite(str(out_path), mask)
|
|
|
|
|
def get_face_mask(landmarks, height, width, out_path, expand_ratio):
    """
    Write a binary mask covering an expanded face-silhouette bounding box.

    Args:
        landmarks (numpy.ndarray): Full set of facial landmarks ([x, y] rows).
        height (int): Mask image height in pixels.
        width (int): Mask image width in pixels.
        out_path (pathlib.Path): Destination of the mask image.
        expand_ratio (float): Factor by which the tight face box is grown.

    Returns:
        None. The mask image is saved at the specified path.
    """
    face_points = landmarks[silhouette_ids]
    lo = np.round(face_points.min(axis=0))
    hi = np.round(face_points.max(axis=0))
    # Grow the tight silhouette box by expand_ratio, clamped to the image.
    min_x, max_x, min_y, max_y = expand_region(
        [lo[0], hi[0], lo[1], hi[1]], width, height, expand_ratio)
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
    cv2.imwrite(str(out_path), mask)
|
|
|
|
|
def get_mask(file, cache_dir, face_expand_raio):
    """
    Build the full set of face/lip/background masks for one image.

    Args:
        file (str): Path to the source image.
        cache_dir (str): Directory where the generated masks are written.
        face_expand_raio (float): Expansion ratio for the face bounding box
            (parameter name kept as-is for caller compatibility).

    Returns:
        None
    """
    landmarks, height, width = get_landmark(file)
    stem = os.path.basename(file).split(".")[0]

    def cache_path(suffix):
        # Every artifact shares the source file's stem.
        return os.path.join(cache_dir, f"{stem}_{suffix}.png")

    get_lip_mask(landmarks, height, width, cache_path("lip_mask"))
    get_face_mask(landmarks, height, width,
                  cache_path("face_mask"), face_expand_raio)
    get_blur_mask(cache_path("face_mask"), cache_path("face_mask_blur"),
                  kernel_size=(51, 51))
    get_blur_mask(cache_path("lip_mask"), cache_path("sep_lip"),
                  kernel_size=(31, 31))
    get_background_mask(cache_path("face_mask_blur"),
                        cache_path("sep_background"))
    get_sep_face_mask(cache_path("face_mask_blur"), cache_path("sep_lip"),
                      cache_path("sep_face"))
|
|
|
|
|
def expand_region(region, image_w, image_h, expand_ratio=1.0):
    """
    Scale a bounding box about its center and clamp it to the image.

    Args:
        region (tuple): (min_x, max_x, min_y, max_y) of the box.
        image_w (int): The width of the image.
        image_h (int): The height of the image.
        expand_ratio (float, optional): Scale factor applied to each side
            length. Defaults to 1.0.

    Returns:
        tuple: Rounded (min_x, max_x, min_y, max_y) of the expanded box.
    """
    min_x, max_x, min_y, max_y = region

    # Scale each axis independently about the box center.
    center_x = (max_x + min_x) // 2
    half_w = ((max_x - min_x) * expand_ratio) // 2
    center_y = (max_y + min_y) // 2
    half_h = ((max_y - min_y) * expand_ratio) // 2
    min_x, max_x = center_x - half_w, center_x + half_w
    min_y, max_y = center_y - half_h, center_y + half_h

    # Shift the box back inside the image, preserving size where possible.
    if min_x < 0:
        max_x -= min_x
        min_x = 0
    if max_x > image_w:
        min_x -= max_x - image_w
        max_x = image_w
    if min_y < 0:
        max_y -= min_y
        min_y = 0
    if max_y > image_h:
        min_y -= max_y - image_h
        max_y = image_h

    return round(min_x), round(max_x), round(min_y), round(max_y)
|
|
|
|
|
def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=(101, 101)):
    """
    Load a grayscale mask, downscale, Gaussian-blur, normalize, and save it.

    Parameters:
        file_path (str): Path to the input mask image.
        output_file_path (str): Where to write the processed mask.
        resize_dim (tuple): Target (width, height) for the resize step.
        kernel_size (tuple): Gaussian kernel size used for blurring.

    Returns:
        str: A human-readable status message.
    """
    mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return f"Failed to load image: {file_path}"

    small = cv2.resize(mask, resize_dim)
    blurred = cv2.GaussianBlur(small, kernel_size, 0)
    # Stretch the blurred values back to the full 0-255 range.
    stretched = cv2.normalize(blurred, None, 0, 255, cv2.NORM_MINMAX)
    cv2.imwrite(output_file_path, stretched)
    return f"Processed, normalized, and saved: {output_file_path}"
|
|
|
|
|
def get_background_mask(file_path, output_file_path):
    """
    Invert a grayscale mask (background = complement of foreground) and save it.

    Parameters:
        file_path (str): Path to the input mask image.
        output_file_path (str): Where to write the inverted mask.
    """
    image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        print(f"Failed to load image: {file_path}")
        return

    # Invert in float space, then quantize back to 8-bit
    # (kept as float math to match the original rounding behavior).
    inverted = ((1.0 - (image / 255.0)) * 255).astype(np.uint8)
    cv2.imwrite(output_file_path, inverted)
    print(f"Processed and saved: {output_file_path}")
|
|
|
|
|
def get_sep_face_mask(file_path1, file_path2, output_file_path):
    """
    Subtract the second mask from the first and save the difference.

    Parameters:
        file_path1 (str): Path to the minuend mask image.
        file_path2 (str): Path to the subtrahend mask image.
        output_file_path (str): Where to write the resulting mask.
    """
    minuend = cv2.imread(file_path1, cv2.IMREAD_GRAYSCALE)
    subtrahend = cv2.imread(file_path2, cv2.IMREAD_GRAYSCALE)
    if minuend is None or subtrahend is None:
        print(f"Failed to load images: {file_path1}")
        return

    if minuend.shape != subtrahend.shape:
        print(
            f"Image shapes do not match for {file_path1}: {minuend.shape} vs {subtrahend.shape}"
        )
        return

    # cv2.subtract saturates at 0, so pixels set only in the second mask
    # cannot underflow.
    cv2.imwrite(output_file_path, cv2.subtract(minuend, subtrahend))
    print(f"Processed and saved: {output_file_path}")
|
|
|
def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
    """
    Resample an audio file to the given sample rate with ffmpeg.

    Args:
        input_audio_file (str): Path of the audio file to resample.
        output_audio_file (str): Path where the resampled audio is written
            (overwritten if it already exists, via ffmpeg -y).
        sample_rate (int): Target sample rate in Hz.

    Returns:
        str: ``output_audio_file``, for convenient chaining.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    result = subprocess.run(
        ["ffmpeg", "-y", "-v", "error", "-i", input_audio_file,
         "-ar", str(sample_rate), output_audio_file],
        check=False,
    )
    # The original used `assert`, which is stripped under `python -O`;
    # raise explicitly so a failed resample can never pass silently.
    if result.returncode != 0:
        raise RuntimeError("Resample audio failed!")
    return output_audio_file
|
|
|
def get_face_region(image_path: str, detector):
    """
    Detect faces in an image and save a rectangle mask of their regions.

    Args:
        image_path (str): Path to the input image; the mask is written to the
            same path with "images" replaced by "face_masks".
        detector: MediaPipe face detector exposing ``detect(mp.Image)``.

    Returns:
        tuple: ``(image_path, mask)`` on success, ``(None, None)`` when the
        image cannot be read or any error occurs.
    """
    try:
        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Failed to open image: {image_path}. Skipping...")
            return None, None

        detection_result = detector.detect(
            mp.Image(image_format=mp.ImageFormat.SRGB, data=frame))

        # Paint a filled white rectangle over every detected bounding box.
        mask = np.zeros_like(frame, dtype=np.uint8)
        for detection in detection_result.detections:
            box = detection.bounding_box
            top_left = (int(box.origin_x), int(box.origin_y))
            bottom_right = (int(box.origin_x + box.width),
                            int(box.origin_y + box.height))
            cv2.rectangle(mask, top_left, bottom_right,
                          (255, 255, 255), thickness=-1)

        save_path = image_path.replace("images", "face_masks")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        cv2.imwrite(save_path, mask)
        return image_path, mask
    except Exception as e:
        # Deliberate best-effort: log and skip bad files during batch runs.
        print(f"Error processing image {image_path}: {e}")
        return None, None
|
|