Spaces:
Paused
Paused
"""Run `bash scripts/download_models.sh` first to prepare the weights files."""
import os | |
import shutil | |
from argparse import Namespace | |
from src.utils.preprocess import CropAndExtract | |
from src.test_audio2coeff import Audio2Coeff | |
from src.facerender.animate import AnimateFromCoeff | |
from src.generate_batch import get_data | |
from src.generate_facerender_batch import get_facerender_data | |
from cog import BasePredictor, Input, Path | |
# Directory (relative to the working directory) that every model weight path
# below is joined onto; see the module docstring for how it gets populated.
checkpoints = "checkpoints"
class Predictor(BasePredictor):
    """Cog predictor that renders a talking-face video from an image and audio.

    ``setup`` builds the preprocessing, audio-to-coefficient and face-render
    models once; ``predict`` chains them for a single request and returns the
    enhanced video.
    """

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        device = "cuda"

        # Weights consumed by the crop / 3DMM extraction stage.
        path_of_lm_croper = os.path.join(
            checkpoints, "shape_predictor_68_face_landmarks.dat"
        )
        path_of_net_recon_model = os.path.join(checkpoints, "epoch_20.pth")
        dir_of_BFM_fitting = os.path.join(checkpoints, "BFM_Fitting")

        # Weights and configs consumed by the audio -> coefficient stage.
        wav2lip_checkpoint = os.path.join(checkpoints, "wav2lip.pth")
        audio2pose_checkpoint = os.path.join(checkpoints, "auido2pose_00140-model.pth")
        audio2pose_yaml_path = os.path.join("src", "config", "auido2pose.yaml")
        audio2exp_checkpoint = os.path.join(checkpoints, "auido2exp_00300-model.pth")
        audio2exp_yaml_path = os.path.join("src", "config", "auido2exp.yaml")

        # Shared by both face-render variants below.
        free_view_checkpoint = os.path.join(
            checkpoints, "facevid2vid_00189-model.pth.tar"
        )

        # init model
        self.preprocess_model = CropAndExtract(
            path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device
        )
        self.audio_to_coeff = Audio2Coeff(
            audio2pose_checkpoint,
            audio2pose_yaml_path,
            audio2exp_checkpoint,
            audio2exp_yaml_path,
            wav2lip_checkpoint,
            device,
        )
        # Two renderers are kept: "full" is used when preprocess == "full",
        # "others" for every other preprocess mode (see predict()).
        self.animate_from_coeff = {
            "full": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00109-model.pth.tar"),
                os.path.join("src", "config", "facerender_still.yaml"),
                device,
            ),
            "others": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00229-model.pth.tar"),
                os.path.join("src", "config", "facerender.yaml"),
                device,
            ),
        }

    def _extract_reference_coeff(self, video_path, results_dir, purpose):
        """Extract 3DMM coefficients from a reference video.

        Args:
            video_path: String path of the reference video.
            results_dir: Directory under which a per-video frame directory
                (named after the video's base name) is created.
            purpose: Text appended to the progress message, e.g. ``"pose"``.

        Returns:
            The first element returned by ``self.preprocess_model.generate``
            (the coefficient path).
        """
        video_name = os.path.splitext(os.path.split(video_path)[-1])[0]
        frame_dir = os.path.join(results_dir, video_name)
        os.makedirs(frame_dir, exist_ok=True)
        print("3DMM Extraction for the reference video providing " + purpose)
        coeff_path, _, _ = self.preprocess_model.generate(video_path, frame_dir)
        return coeff_path

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image, it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio, accepts .wav and .mp4 file",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="how to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="path to reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="path to reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="can crop back to the original videos for the full body aniamtion when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model

        Args:
            source_image: Source image or video to animate.
            driven_audio: Audio (or video with audio) that drives the animation.
            enhancer: Name of the face enhancer forwarded to the renderer.
            preprocess: Preprocess mode; also selects which renderer is used.
            ref_eyeblink: Optional reference video for eye blinking.
            ref_pose: Optional reference video for pose.
            still: Forwarded as the still-mode flag to the batch builders.

        Returns:
            Path to ``/tmp/out.mp4``, a copy of the enhanced video found in
            the results directory, or ``None`` when no coefficients could be
            extracted from the source image.

        Raises:
            FileNotFoundError: If rendering finished but no file whose name
                contains ``enhanced.mp4`` exists in the results directory.
        """
        animate_from_coeff = (
            self.animate_from_coeff["full"]
            if preprocess == "full"
            else self.animate_from_coeff["others"]
        )

        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        device = "cuda"
        args.still = still
        args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
        args.ref_pose = None if ref_pose is None else str(ref_pose)

        # crop image and extract 3dmm from image
        # The results directory is wiped on every call so the output lookup at
        # the end only sees files produced by this prediction.
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            print("Can't get the coeffs of the input")
            # NOTE(review): returns None despite the ``-> Path`` annotation;
            # kept as-is so existing behaviour does not change.
            return None

        # Reference videos are passed as plain strings, the same way the
        # source image is; the string forms were already stored on ``args``.
        ref_eyeblink_coeff_path = None
        if args.ref_eyeblink is not None:
            ref_eyeblink_coeff_path = self._extract_reference_coeff(
                args.ref_eyeblink, results_dir, "eye blinking"
            )

        ref_pose_coeff_path = None
        if args.ref_pose is not None:
            if args.ref_pose == args.ref_eyeblink:
                # Same video for both roles: reuse the extraction result.
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_coeff_path = self._extract_reference_coeff(
                    args.ref_pose, results_dir, "pose"
                )

        # audio2ceoff
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )

        # coeff2video
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data,
            results_dir,
            args.pic_path,
            crop_info,
            enhancer=enhancer,
            background_enhancer=args.background_enhancer,
            preprocess=preprocess,
        )

        # Pick the first directory entry whose name contains "enhanced.mp4";
        # fail with a descriptive error instead of an IndexError if none exists.
        enhanced_name = next(
            (name for name in os.listdir(results_dir) if "enhanced.mp4" in name),
            None,
        )
        if enhanced_name is None:
            raise FileNotFoundError(
                "No 'enhanced.mp4' output was found in " + results_dir
            )

        output = "/tmp/out.mp4"
        shutil.copy(os.path.join(results_dir, enhanced_name), output)
        return Path(output)
def load_default():
    """Build a fresh ``Namespace`` holding the default pipeline options.

    Returns:
        A new ``argparse.Namespace`` on every call, so callers may attach or
        overwrite attributes without affecting later calls.
    """
    defaults = {
        # Options read by the prediction pipeline.
        "pose_style": 0,
        "batch_size": 2,
        "expression_scale": 1.0,
        "input_yaw": None,
        "input_pitch": None,
        "input_roll": None,
        "background_enhancer": None,
        "face3dvis": False,
        # Remaining defaults (reconstruction network / BFM / camera values).
        "net_recon": "resnet50",
        "init_path": None,
        "use_last_fc": False,
        "bfm_folder": "./checkpoints/BFM_Fitting/",
        "bfm_model": "BFM_model_front.mat",
        "focal": 1015.0,
        "center": 112.0,
        "camera_d": 10.0,
        "z_near": 5.0,
        "z_far": 15.0,
    }
    return Namespace(**defaults)