# pylint: disable=C0116
# pylint: disable=W0718
# pylint: disable=R1732
"""
utils.py

This module provides utility functions for various tasks such as setting random
seeds, importing modules from files, managing checkpoint files, and saving video
files from sequences of PIL images.

Functions:
    seed_everything(seed)
    import_filename(filename)
    delete_additional_ckpt(base_path, num_keep)
    save_videos_from_pil(pil_images, path, fps=8)

Dependencies:
    importlib, os, os.path, random, shutil, subprocess, sys, pathlib.Path,
    av, cv2, mediapipe, numpy, torch, torchvision, einops, moviepy, PIL

Examples:
    seed_everything(42)
    imported_module = import_filename('path/to/your/module.py')
    delete_additional_ckpt('path/to/checkpoints', 1)
    save_videos_from_pil(pil_images, 'output/video.mp4', fps=12)

The functions in this module ensure reproducibility of experiments by seeding
random number generators, allow dynamic importing of modules, manage checkpoint
files by deleting extra ones, and provide a way to save sequences of images as
video files.
"""
import importlib
import os
import os.path as osp
import random
import shutil
import subprocess
import sys
from pathlib import Path

import av
import cv2
import mediapipe as mp
import numpy as np
import torch
import torchvision
from einops import rearrange
from moviepy.editor import AudioFileClip, VideoClip
from PIL import Image
def seed_everything(seed):
    """
    Seed all random number generators to ensure reproducibility.

    Args:
        seed (int): The seed value to set for all random number generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # numpy only accepts seeds that fit in 32 bits
    np.random.seed(seed % (2**32))
    random.seed(seed)
def import_filename(filename):
    """
    Import a module from a given file location.

    Args:
        filename (str): The path to the file containing the module to import.

    Returns:
        module: The imported module (also registered in ``sys.modules``
        under the name "mymodule").

    Raises:
        ImportError: If the module cannot be imported.

    Example:
        >>> imported_module = import_filename('path/to/your/module.py')
    """
    spec = importlib.util.spec_from_file_location("mymodule", filename)
    module = importlib.util.module_from_spec(spec)
    # Register before exec so the module can self-reference during import.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
def delete_additional_ckpt(base_path, num_keep):
    """
    Delete older "checkpoint-<step>" directories, keeping the most recent.

    Args:
        base_path (str): Directory containing the checkpoint directories.
        num_keep (int): Number of most recent checkpoints to keep.

    Returns:
        None

    Raises:
        FileNotFoundError: If base_path does not exist.

    Example:
        >>> delete_additional_ckpt('path/to/checkpoints', 1)
        # Deletes all but the most recent checkpoint in 'path/to/checkpoints'.
    """
    dirs = [d for d in os.listdir(base_path) if d.startswith("checkpoint-")]
    num_tot = len(dirs)
    if num_tot <= num_keep:
        return
    # Sort by step number and delete the earliest ones.
    del_dirs = sorted(dirs, key=lambda x: int(
        x.split("-")[-1]))[: num_tot - num_keep]
    for d in del_dirs:
        path_to_dir = osp.join(base_path, d)
        if osp.exists(path_to_dir):
            shutil.rmtree(path_to_dir)
def save_videos_from_pil(pil_images, path, fps=8):
    """
    Save a sequence of PIL images as a video (.mp4 via PyAV) or animation (.gif).

    Args:
        pil_images (List[PIL.Image]): Frames of the video, all the same size.
        path (str): Output file path; the extension selects the format.
        fps (int, optional): Frames per second. Defaults to 8.

    Returns:
        None

    Raises:
        ValueError: If the file extension is not .mp4 or .gif.
    """
    save_fmt = Path(path).suffix
    out_dir = os.path.dirname(path)
    # Guard against a bare filename: makedirs("") raises FileNotFoundError.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    width, height = pil_images[0].size

    if save_fmt == ".mp4":
        codec = "libx264"
        container = av.open(path, "w")
        stream = container.add_stream(codec, rate=fps)
        stream.width = width
        stream.height = height
        for pil_image in pil_images:
            av_frame = av.VideoFrame.from_image(pil_image)
            container.mux(stream.encode(av_frame))
        # Flush any frames still buffered inside the encoder.
        container.mux(stream.encode())
        container.close()
    elif save_fmt == ".gif":
        pil_images[0].save(
            fp=path,
            format="GIF",
            append_images=pil_images[1:],
            save_all=True,
            duration=(1 / fps * 1000),  # per-frame duration in milliseconds
            loop=0,
        )
    else:
        raise ValueError("Unsupported file type. Use .mp4 or .gif.")
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    """
    Arrange a batch of videos into a grid and save as a video/animation.

    Args:
        videos (torch.Tensor): Tensor of shape (batch, channels, time, height,
            width) containing the videos to save.
        path (str): Output path; supported formats are .mp4 and .gif
            (see save_videos_from_pil).
        rescale (bool, optional): If True, map values from [-1, 1] to [0, 1]
            before quantizing. Defaults to False.
        n_rows (int, optional): Number of grid columns passed to make_grid's
            ``nrow``. Defaults to 6.
        fps (int, optional): Frame rate of the saved video. Defaults to 8.

    Raises:
        ValueError: If the output format is not supported.

    Returns:
        None
    """
    # (b, c, t, h, w) -> iterate over time, gridding the whole batch per frame.
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)  # (c, h, w)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (h, w, c)
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(Image.fromarray(x))
    out_dir = os.path.dirname(path)
    # Guard against a bare filename: makedirs("") raises FileNotFoundError.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_videos_from_pil(outputs, path, fps)
def read_frames(video_path):
    """
    Decode all frames of a video file into PIL images.

    Args:
        video_path (str): The path to the video file.

    Returns:
        List[PIL.Image]: The decoded RGB frames, in decode order.

    Raises:
        FileNotFoundError: If the video file is not found.
        RuntimeError: If there is an error reading the video stream.
    """
    container = av.open(video_path)
    video_stream = next(s for s in container.streams if s.type == "video")
    frames = []
    for packet in container.demux(video_stream):
        for frame in packet.decode():
            image = Image.frombytes(
                "RGB",
                (frame.width, frame.height),
                frame.to_rgb().to_ndarray(),
            )
            frames.append(image)
    # Release the demuxer/decoder resources (the original leaked the handle).
    container.close()
    return frames
def get_fps(video_path):
    """
    Get the average frame rate (FPS) of a video file.

    Args:
        video_path (str): The path to the video file.

    Returns:
        fractions.Fraction: The video stream's average frame rate as reported
        by PyAV (may be non-integral, e.g. 30000/1001).
    """
    container = av.open(video_path)
    video_stream = next(s for s in container.streams if s.type == "video")
    fps = video_stream.average_rate
    container.close()
    return fps
def tensor_to_video(tensor, output_video_file, audio_source, fps=25):
    """
    Convert a tensor shaped [c, f, h, w] into a video with an audio track.

    Args:
        tensor (Tensor): The tensor to convert, shaped [c, f, h, w]; values
            are assumed to be in [0, 1] before quantization.
        output_video_file (str): The file path where the output video is saved.
        audio_source (str): Path to the audio file (e.g. WAV) providing the
            soundtrack; it is trimmed to the video duration.
        fps (int): The frame rate of the output video. Defaults to 25.
    """
    # [c, f, h, w] -> [f, h, w, c], moved to CPU for moviepy.
    tensor = tensor.permute(1, 2, 3, 0).cpu().numpy()
    tensor = np.clip(tensor * 255, 0, 255).astype(np.uint8)  # to [0, 255]

    def make_frame(t):
        # Map the clip time to the nearest valid frame index.
        frame_index = min(int(t * fps), tensor.shape[0] - 1)
        return tensor[frame_index]

    duration = tensor.shape[0] / fps
    new_video_clip = VideoClip(make_frame, duration=duration)
    audio_clip = AudioFileClip(audio_source).subclip(0, duration)
    new_video_clip = new_video_clip.set_audio(audio_clip)
    new_video_clip.write_videofile(output_video_file, fps=fps)
# MediaPipe FaceMesh landmark indices outlining the face silhouette.
silhouette_ids = [
    10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
    397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
    172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109
]
# MediaPipe FaceMesh landmark indices for the lips region.
lip_ids = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291,
           146, 91, 181, 84, 17, 314, 405, 321, 375]
def compute_face_landmarks(detection_result, h, w):
    """
    Convert a FaceLandmarker result into pixel-space landmark coordinates.

    Args:
        detection_result: A MediaPipe FaceLandmarker result exposing a
            ``face_landmarks`` list (one entry per detected face).
        h (int): The height of the video frame.
        w (int): The width of the video frame.

    Returns:
        list: [x, y] pixel coordinates for each landmark of the single
        detected face, or [] if there is not exactly one face.
    """
    face_landmarks_list = detection_result.face_landmarks
    # Exactly one face is required; anything else is treated as a failure.
    if len(face_landmarks_list) != 1:
        print("#face is invalid:", len(face_landmarks_list))
        return []
    return [[p.x * w, p.y * h] for p in face_landmarks_list[0]]
def get_landmark(file):
    """
    Detect facial landmarks in an image using MediaPipe FaceLandmarker.

    Args:
        file (str): Path to the image file to process.

    Returns:
        Tuple[np.ndarray, int, int]: The landmarks as an array of [x, y]
        pixel coordinates (empty when detection fails), plus the image
        height and width.
    """
    model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task"
    BaseOptions = mp.tasks.BaseOptions
    FaceLandmarker = mp.tasks.vision.FaceLandmarker
    FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
    VisionRunningMode = mp.tasks.vision.RunningMode

    # Create a face landmarker instance in single-image mode.
    options = FaceLandmarkerOptions(
        base_options=BaseOptions(model_asset_path=model_path),
        running_mode=VisionRunningMode.IMAGE,
    )
    with FaceLandmarker.create_from_options(options) as landmarker:
        image = mp.Image.create_from_file(str(file))
        height, width = image.height, image.width
        face_landmarker_result = landmarker.detect(image)
        face_landmark = compute_face_landmarks(
            face_landmarker_result, height, width)

    return np.array(face_landmark), height, width
def get_lip_mask(landmarks, height, width, out_path):
    """
    Extract the lip region from the given landmarks and save it as a mask.

    Parameters:
        landmarks (numpy.ndarray): Array of facial landmark [x, y] coordinates.
        height (int): Height of the output lip mask image.
        width (int): Width of the output lip mask image.
        out_path (pathlib.Path): Path to save the lip mask image.
    """
    lip_landmarks = np.take(landmarks, lip_ids, 0)
    min_xy_lip = np.round(np.min(lip_landmarks, 0))
    max_xy_lip = np.round(np.max(lip_landmarks, 0))
    # Grow the lip bounding box by 2x, clamped to the image bounds.
    min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1] = expand_region(
        [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, 2.0)
    lip_mask = np.zeros((height, width), dtype=np.uint8)
    lip_mask[round(min_xy_lip[1]):round(max_xy_lip[1]),
             round(min_xy_lip[0]):round(max_xy_lip[0])] = 255
    cv2.imwrite(str(out_path), lip_mask)
def get_face_mask(landmarks, height, width, out_path, expand_ratio):
    """
    Generate a rectangular face mask based on the given landmarks.

    Args:
        landmarks (numpy.ndarray): Array of facial landmark [x, y] coordinates.
        height (int): The height of the output face mask image.
        width (int): The width of the output face mask image.
        out_path (pathlib.Path): The path to save the face mask image.
        expand_ratio (float): Factor by which the face bounding box is grown.

    Returns:
        None. The face mask image is saved at the specified path.
    """
    face_landmarks = np.take(landmarks, silhouette_ids, 0)
    min_xy_face = np.round(np.min(face_landmarks, 0))
    max_xy_face = np.round(np.max(face_landmarks, 0))
    # Grow the face bounding box, clamped to the image bounds.
    min_xy_face[0], max_xy_face[0], min_xy_face[1], max_xy_face[1] = expand_region(
        [min_xy_face[0], max_xy_face[0], min_xy_face[1], max_xy_face[1]], width, height, expand_ratio)
    face_mask = np.zeros((height, width), dtype=np.uint8)
    face_mask[round(min_xy_face[1]):round(max_xy_face[1]),
              round(min_xy_face[0]):round(max_xy_face[0])] = 255
    cv2.imwrite(str(out_path), face_mask)
def get_mask(file, cache_dir, face_expand_raio):
    """
    Generate the full set of face/lip/background masks for an image.

    Detects landmarks, then writes lip, face, blurred, background and
    separated-face mask images into cache_dir, named after the input file.

    Args:
        file (str): Path to the input image.
        cache_dir (str): Directory where the generated masks are saved.
        face_expand_raio (float): Expansion ratio for the face bounding box
            (parameter name kept as-is for caller compatibility).

    Returns:
        None
    """
    landmarks, height, width = get_landmark(file)
    file_name = os.path.basename(file).split(".")[0]
    get_lip_mask(landmarks, height, width, os.path.join(
        cache_dir, f"{file_name}_lip_mask.png"))
    get_face_mask(landmarks, height, width, os.path.join(
        cache_dir, f"{file_name}_face_mask.png"), face_expand_raio)
    get_blur_mask(os.path.join(
        cache_dir, f"{file_name}_face_mask.png"), os.path.join(
        cache_dir, f"{file_name}_face_mask_blur.png"), kernel_size=(51, 51))
    get_blur_mask(os.path.join(
        cache_dir, f"{file_name}_lip_mask.png"), os.path.join(
        cache_dir, f"{file_name}_sep_lip.png"), kernel_size=(31, 31))
    get_background_mask(os.path.join(
        cache_dir, f"{file_name}_face_mask_blur.png"), os.path.join(
        cache_dir, f"{file_name}_sep_background.png"))
    get_sep_face_mask(os.path.join(
        cache_dir, f"{file_name}_face_mask_blur.png"), os.path.join(
        cache_dir, f"{file_name}_sep_lip.png"), os.path.join(
        cache_dir, f"{file_name}_sep_face.png"))
def expand_region(region, image_w, image_h, expand_ratio=1.0):
    """
    Expand a bounding region about its center by a given ratio.

    The expanded box is shifted back inside the image if it overflows a
    border, so the returned box always lies within [0, image_w] x [0, image_h]
    (provided the expanded size fits in the image).

    Args:
        region (tuple): The (min_x, max_x, min_y, max_y) coordinates.
        image_w (int): The width of the image.
        image_h (int): The height of the image.
        expand_ratio (float, optional): Scale factor applied to each side
            length. Defaults to 1.0.

    Returns:
        tuple: The expanded, clamped (min_x, max_x, min_y, max_y), rounded
        to integers.
    """
    min_x, max_x, min_y, max_y = region
    mid_x = (max_x + min_x) // 2
    side_len_x = (max_x - min_x) * expand_ratio
    mid_y = (max_y + min_y) // 2
    side_len_y = (max_y - min_y) * expand_ratio
    min_x = mid_x - side_len_x // 2
    max_x = mid_x + side_len_x // 2
    min_y = mid_y - side_len_y // 2
    max_y = mid_y + side_len_y // 2
    # Shift the box back inside the image instead of truncating it.
    if min_x < 0:
        max_x -= min_x
        min_x = 0
    if max_x > image_w:
        min_x -= max_x - image_w
        max_x = image_w
    if min_y < 0:
        max_y -= min_y
        min_y = 0
    if max_y > image_h:
        min_y -= max_y - image_h
        max_y = image_h
    return round(min_x), round(max_x), round(min_y), round(max_y)
def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=(101, 101)):
    """
    Read, resize, blur, normalize, and save a grayscale mask image.

    Parameters:
        file_path (str): Path to the input mask image.
        output_file_path (str): Path where the blurred mask is saved.
        resize_dim (tuple): Dimensions to resize the mask to before blurring.
        kernel_size (tuple): Gaussian blur kernel size (odd dimensions).

    Returns:
        str: A message describing success or failure.
    """
    # Read the mask image as grayscale.
    mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    if mask is not None:
        # Resize, then soften the edges with a Gaussian blur.
        resized_mask = cv2.resize(mask, resize_dim)
        blurred_mask = cv2.GaussianBlur(resized_mask, kernel_size, 0)
        # Stretch the blurred values back to the full [0, 255] range.
        normalized_mask = cv2.normalize(
            blurred_mask, None, 0, 255, cv2.NORM_MINMAX)
        cv2.imwrite(output_file_path, normalized_mask)
        return f"Processed, normalized, and saved: {output_file_path}"
    return f"Failed to load image: {file_path}"
def get_background_mask(file_path, output_file_path):
    """
    Read a grayscale mask, invert its values, and save the result.

    Parameters:
        file_path (str): Path to the input mask image.
        output_file_path (str): Path where the inverted mask is saved.
    """
    image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        print(f"Failed to load image: {file_path}")
        return

    # Invert, assuming pixel values are in the [0, 255] range.
    inverted_image = 1.0 - (image / 255.0)
    # Convert back to uint8 for writing.
    inverted_image = (inverted_image * 255).astype(np.uint8)
    cv2.imwrite(output_file_path, inverted_image)
    print(f"Processed and saved: {output_file_path}")
def get_sep_face_mask(file_path1, file_path2, output_file_path):
    """
    Subtract the second mask image from the first and save the result.

    Parameters:
        file_path1 (str): Path to the minuend mask image (e.g. full face).
        file_path2 (str): Path to the subtrahend mask image (e.g. lips).
        output_file_path (str): Path where the resulting mask is saved.
    """
    mask1 = cv2.imread(file_path1, cv2.IMREAD_GRAYSCALE)
    mask2 = cv2.imread(file_path2, cv2.IMREAD_GRAYSCALE)
    if mask1 is None or mask2 is None:
        print(f"Failed to load images: {file_path1}")
        return
    # The subtraction is only meaningful for same-sized masks.
    if mask1.shape != mask2.shape:
        print(
            f"Image shapes do not match for {file_path1}: {mask1.shape} vs {mask2.shape}"
        )
        return
    # cv2.subtract saturates at 0, so the result stays a valid uint8 mask.
    result_mask = cv2.subtract(mask1, mask2)
    cv2.imwrite(output_file_path, result_mask)
    print(f"Processed and saved: {output_file_path}")
def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
    """
    Resample an audio file to the given sample rate using ffmpeg.

    Args:
        input_audio_file (str): Path to the source audio file.
        output_audio_file (str): Path where the resampled audio is written
            (overwritten if it exists).
        sample_rate (int): Target sample rate in Hz.

    Returns:
        str: The output audio file path.

    Raises:
        AssertionError: If the ffmpeg invocation fails.
    """
    # subprocess.run waits and cleans up the process (no dangling Popen).
    completed = subprocess.run(
        ["ffmpeg", "-y", "-v", "error", "-i", input_audio_file,
         "-ar", str(sample_rate), output_audio_file],
        check=False,
    )
    assert completed.returncode == 0, "Resample audio failed!"
    return output_audio_file
def get_face_region(image_path: str, detector):
    """
    Detect faces in an image and save a rectangular face mask for it.

    The mask is white inside every detected face bounding box and black
    elsewhere, and is written to the input path with "images" replaced by
    "face_masks".

    Args:
        image_path (str): Path to the input image.
        detector: A MediaPipe face detector exposing ``detect(mp.Image)``.

    Returns:
        Tuple[str, np.ndarray] | Tuple[None, None]: The image path and mask
        on success, or (None, None) if loading or detection fails.
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            print(f"Failed to open image: {image_path}. Skipping...")
            return None, None

        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
        detection_result = detector.detect(mp_image)

        # Three-channel mask matching the input image's shape.
        mask = np.zeros_like(image, dtype=np.uint8)
        for detection in detection_result.detections:
            bbox = detection.bounding_box
            start_point = (int(bbox.origin_x), int(bbox.origin_y))
            end_point = (int(bbox.origin_x + bbox.width),
                         int(bbox.origin_y + bbox.height))
            # thickness=-1 fills the rectangle.
            cv2.rectangle(mask, start_point, end_point,
                          (255, 255, 255), thickness=-1)

        # Mirror the "images" directory layout under "face_masks".
        save_path = image_path.replace("images", "face_masks")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        cv2.imwrite(save_path, mask)
        return image_path, mask
    except Exception as e:
        # Best-effort batch processing: log and skip (W0718 disabled above).
        print(f"Error processing image {image_path}: {e}")
        return None, None