Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import ffmpeg | |
from datetime import datetime | |
from pathlib import Path | |
import numpy as np | |
import cv2 | |
import torch | |
import spaces | |
from scipy.spatial.transform import Rotation as R | |
from scipy.interpolate import interp1d | |
from diffusers import AutoencoderKL, DDIMScheduler | |
from einops import repeat | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from torchvision import transforms | |
from transformers import CLIPVisionModelWithProjection | |
from src.models.pose_guider import PoseGuider | |
from src.models.unet_2d_condition import UNet2DConditionModel | |
from src.models.unet_3d import UNet3DConditionModel | |
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline | |
from src.utils.util import save_videos_grid | |
from src.audio_models.model import Audio2MeshModel | |
from src.utils.audio_util import prepare_audio_feature | |
from src.utils.mp_utils import LMKExtractor | |
from src.utils.draw_util import FaceMeshVisualizer | |
from src.utils.pose_util import project_points | |
def matrix_to_euler_and_translation(matrix):
    """Decompose a rigid transform into Euler angles and a translation.

    Only the top-left 3x3 block (rotation) and the first three entries of the
    fourth column (translation) are read, so ``matrix`` must have at least
    3 rows and 4 columns (e.g. a 4x4 homogeneous transform).

    Returns:
        (euler_angles, translation_vector):
            ``euler_angles`` are the 'xyz' Euler angles in degrees as computed
            by scipy's ``Rotation.as_euler`` (lowercase axes = extrinsic in
            scipy's convention).
            ``translation_vector`` is the slice ``matrix[:3, 3]``.
    """
    rot = R.from_matrix(matrix[:3, :3])
    return rot.as_euler('xyz', degrees=True), matrix[:3, 3]
def smooth_pose_seq(pose_seq, window_size=5):
    """Smooth a sequence with a centered moving average along axis 0.

    Element ``i`` becomes the mean of elements
    ``i - window_size // 2`` .. ``i + window_size // 2`` (inclusive), clipped
    to the sequence bounds, so windows shrink at both ends. An even
    ``window_size`` therefore behaves like ``window_size + 1``; ``0`` or ``1``
    returns the input values unchanged.

    Args:
        pose_seq: array-like whose first axis is time.
        window_size: nominal window length.

    Returns:
        np.ndarray with the same shape as ``pose_seq``. Floating/complex input
        keeps its dtype; integer/bool input yields float64 so the means are
        not truncated.
    """
    pose_arr = np.asarray(pose_seq)
    # np.zeros_like on integer input would store the means as truncated ints.
    if np.issubdtype(pose_arr.dtype, np.inexact):
        out_dtype = pose_arr.dtype
    else:
        out_dtype = np.float64
    smoothed_pose_seq = np.zeros(pose_arr.shape, dtype=out_dtype)

    half = window_size // 2
    n = len(pose_arr)
    for i in range(n):
        start = max(0, i - half)
        end = min(n, i + half + 1)
        smoothed_pose_seq[i] = np.mean(pose_arr[start:end], axis=0)
    return smoothed_pose_seq
def get_headpose_temp(input_video):
    """Extract a smoothed 30 fps head-pose sequence from a driving video.

    Each frame's face transform (from ``LMKExtractor``) is expressed relative
    to the first frame. It is converted to 6 values: 'xyz' Euler angles in
    degrees followed by the translation. The sequence is then linearly
    resampled to 30 fps and smoothed with ``smooth_pose_seq``.

    Args:
        input_video: anything ``cv2.VideoCapture`` accepts (e.g. a file path).

    Returns:
        np.ndarray of shape (num_resampled_frames, 6); at least one row.

    Raises:
        ValueError: if the video cannot be opened, contains no frames, or no
            face is detected in any frame.
    """
    trans_mat_arr, fps = _read_face_transforms(input_video)
    pose_arr = _relative_pose_sequence(trans_mat_arr)
    pose_arr_interp = _resample_pose_sequence(pose_arr, fps, new_fps=30)
    return smooth_pose_seq(pose_arr_interp)


def _read_face_transforms(input_video):
    """Decode every frame and return (stacked per-frame face transforms, fps)."""
    lmk_extractor = LMKExtractor()
    cap = cv2.VideoCapture(input_video)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open head-pose video: {input_video}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        trans_mats = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            result = lmk_extractor(frame)
            # The extractor returns None when no face is found; keep a
            # placeholder so the sequence stays one entry per decoded frame.
            trans_mats.append(None if result is None else result['trans_mat'].astype(np.float32))
    finally:
        cap.release()

    first_valid = next((m for m in trans_mats if m is not None), None)
    if first_valid is None:
        raise ValueError(f"No face detected in head-pose video: {input_video}")
    # Forward-fill frames without a detection; leading misses take the first
    # detection. This preserves timing instead of dropping frames.
    last = first_valid
    for i, mat in enumerate(trans_mats):
        if mat is None:
            trans_mats[i] = last
        else:
            last = mat
    return np.stack(trans_mats), fps


def _relative_pose_sequence(trans_mat_arr):
    """Convert (N, k, k) transforms (presumably 4x4) to (N, 6) poses relative to frame 0."""
    trans_mat_inv_frame_0 = np.linalg.inv(trans_mat_arr[0])
    pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
    for i, trans_mat in enumerate(trans_mat_arr):
        euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_inv_frame_0 @ trans_mat)
        pose_arr[i, :3] = euler_angles
        pose_arr[i, 3:6] = translation_vector
    return pose_arr


def _resample_pose_sequence(pose_arr, fps, new_fps=30):
    """Linearly resample a per-frame (N, 6) pose sequence from ``fps`` to ``new_fps``."""
    # Use the number of frames actually decoded: CAP_PROP_FRAME_COUNT is only
    # a container estimate, and a mismatch would make interp1d raise.
    num_frames = len(pose_arr)
    # Some containers report 0 (or NaN) fps; treat the video as already at new_fps.
    if not fps > 0:
        fps = new_fps
    duration = num_frames / fps
    num_new = max(int(num_frames * new_fps / fps), 1)
    if num_frames < 2:
        # interp1d needs at least two samples; a single pose is constant in time.
        return np.repeat(pose_arr[:1], num_new, axis=0)
    old_time = np.linspace(0, duration, num_frames)
    new_time = np.linspace(0, duration, num_new)
    # Interpolating along axis 0 handles all 6 pose channels at once.
    return interp1d(old_time, pose_arr, axis=0)(new_time)
def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=150, seed=42):
    """Generate a talking-head video of ``ref_img`` driven by ``input_audio``.

    Steps: load all models from ``./configs/prompts/animation_audio.yaml``;
    predict a per-frame 3D face mesh from the audio; add head motion (from
    ``headpose_video`` or the stored ``pose_temp`` template); project it to
    2D landmark images; render with ``Pose2VideoPipeline``; mux the input
    audio back in with ffmpeg.

    Args:
        input_audio: path to the driving audio file (read by
            ``prepare_audio_feature`` and by ffmpeg).
        ref_img: reference portrait as an RGB numpy image (converted with
            ``COLOR_RGB2BGR`` below).
        headpose_video: optional video whose head motion is transferred; when
            None, ``config['pose_temp']`` is loaded with ``np.load``.
        size: output width and height in pixels.
        steps: number of denoising steps passed to the pipeline.
        length: maximum number of frames to generate; ``0`` (or any
            non-positive value) means every available frame.
        seed: seed for ``torch.manual_seed``.

    Returns:
        Path (str) of the final ``.mp4`` with audio, or None when no face is
        detected in ``ref_img``.

    Raises:
        ValueError: if the head-pose sequence is empty.
    """
    fps = 30
    cfg = 3.5  # passed to the pipeline right after `steps`; presumably the guidance scale
    config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
    weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32
    audio_infer_config = OmegaConf.load(config.audio_inference_config)

    # --- prepare models (all of them are reloaded on every call) ---
    a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
    a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
    a2m_model.cuda().eval()
    vae = AutoencoderKL.from_pretrained(
        config.pretrained_vae_path,
    ).to("cuda", dtype=weight_dtype)
    reference_unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_base_model_path,
        subfolder="unet",
    ).to(dtype=weight_dtype, device="cuda")
    infer_config = OmegaConf.load(config.inference_config)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device="cuda")
    # NOTE(review): a previous comment here said "not use cross attention",
    # which contradicts use_ca=True; confirm which is intended.
    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype)
    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path
    ).to(dtype=weight_dtype, device="cuda")
    sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
    scheduler = DDIMScheduler(**sched_kwargs)
    generator = torch.manual_seed(seed)
    width, height = size, size

    # --- load pretrained weights ---
    denoising_unet.load_state_dict(
        torch.load(config.denoising_unet_path, map_location="cpu"),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(config.reference_unet_path, map_location="cpu"),
    )
    pose_guider.load_state_dict(
        torch.load(config.pose_guider_path, map_location="cpu"),
    )
    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    # A single timestamp so the date and time parts cannot straddle a boundary.
    # NOTE(review): minute resolution means concurrent runs with the same
    # seed/size share this directory and file names.
    now = datetime.now()
    date_str = now.strftime("%Y%m%d")
    time_str = now.strftime("%H%M")
    save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
    save_dir = Path(f"output/{date_str}/{save_dir_name}")
    save_dir.mkdir(exist_ok=True, parents=True)

    # --- reference image and its landmark pose ---
    lmk_extractor = LMKExtractor()
    vis = FaceMeshVisualizer(forehead_edge=False)
    ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
    # TODO: face detection + cropping (the whole image is currently just resized)
    ref_image_np = cv2.resize(ref_image_np, (size, size))
    ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
    face_result = lmk_extractor(ref_image_np)
    if face_result is None:
        return None
    lmks = face_result['lmks'].astype(np.float32)
    ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)

    # --- audio -> 3D mesh sequence ---
    sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
    sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
    sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
    # Pure inference: no_grad avoids building an autograd graph that is never used.
    with torch.no_grad():
        pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
    pred = pred.squeeze().detach().cpu().numpy()
    pred = pred.reshape(pred.shape[0], -1, 3)
    # The prediction is added to the reference face's 3D landmarks
    # (presumably per-vertex offsets; confirm against Audio2MeshModel).
    pred = pred + face_result['lmks3d']

    # --- head motion ---
    if headpose_video is not None:
        pose_seq = get_headpose_temp(headpose_video)
    else:
        pose_seq = np.load(config['pose_temp'])
    if len(pose_seq) == 0:
        raise ValueError("Head-pose sequence is empty")
    # Ping-pong the sequence (forward, then backward without repeating either
    # endpoint) and tile it to cover every audio frame, so the loop is seamless.
    mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
    cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]

    # project 3D mesh to 2D landmark images
    projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
    pose_images = [
        vis.draw_landmarks((width, height), verts, normed=False)
        for verts in projected_vertices
    ]

    # Non-positive length means "all frames"; otherwise cap at what is available.
    num_frames = len(pose_images) if length <= 0 or length > len(pose_images) else length
    pose_list = np.array([cv2.resize(img, (width, height)) for img in pose_images[:num_frames]])
    video_length = len(pose_list)

    video = pipe(
        ref_image_pil,
        pose_list,
        ref_pose,
        width,
        height,
        video_length,
        steps,
        cfg,
        generator=generator,
    ).videos

    save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
    save_videos_grid(
        video,
        save_path,
        n_rows=1,
        fps=fps,
    )

    # --- mux the driving audio into the rendered video ---
    final_path = save_path.replace('_noaudio.mp4', '.mp4')
    stream = ffmpeg.input(save_path)
    audio = ffmpeg.input(input_audio)
    try:
        # overwrite_output (-y): a rerun within the same minute with the same
        # seed/size reuses final_path, and ffmpeg would otherwise refuse or prompt.
        # shortest=None (-shortest): the video may be truncated to `length`
        # frames, so end at the shorter stream instead of letting audio run past it.
        (
            ffmpeg.output(stream.video, audio.audio, final_path, vcodec='copy', acodec='aac', shortest=None)
            .overwrite_output()
            .run()
        )
    finally:
        # Always remove the intermediate silent video, even if muxing failed.
        os.remove(save_path)
    return final_path