# one-shot-talking-face / test_script.py
# Uploaded by DmitrMakeev ("Upload 7 files", commit c626b55).
import argparse
import os
import subprocess
import tempfile

import imageio
import numpy as np
import torch
import yaml
from scipy.io import wavfile

import config
from models.generator import OcclusionAwareGenerator
from models.keypoint_detector import KPDetector
from models.transformer import Audio2kpTransformer
from models.util import draw_annotation_box
from tools.interface import (get_audio_feature_from_audio, get_img_pose, get_pose_from_audio,
                             load_ckpt, parse_phoneme_file, read_img)
def normalize_kp(kp_source, kp_driving, kp_driving_initial,
                 use_relative_movement=True, use_relative_jacobian=True):
    """Transfer the motion of the driving keypoints onto the source keypoints.

    With relative movement enabled:

    * ``value``    -> ``(kp_driving - kp_driving_initial) + kp_source``
    * ``jacobian`` -> ``J_driving @ inv(J_driving_initial) @ J_source``
      (only when ``use_relative_jacobian`` is also True)

    Args:
        kp_source: keypoint dict of the source image (``'value'``/``'jacobian'``).
        kp_driving: keypoint dict of the current driving frame.
        kp_driving_initial: keypoint dict of the reference (first) driving frame.
        use_relative_movement: if False, a shallow copy of ``kp_driving`` is
            returned as-is.
        use_relative_jacobian: also re-express the jacobians; has no effect
            unless ``use_relative_movement`` is True.

    Returns:
        A new dict. The input dicts are not modified, and any keys other than
        ``value``/``jacobian`` are carried over from ``kp_driving``.
    """
    result = dict(kp_driving.items())
    if not use_relative_movement:
        return result

    # Displacement since the reference frame, applied on top of the source
    # keypoints (no movement scaling is applied here).
    displacement = kp_driving['value'] - kp_driving_initial['value']
    result['value'] = displacement + kp_source['value']

    if use_relative_jacobian:
        relative_jac = torch.matmul(kp_driving['jacobian'],
                                    torch.inverse(kp_driving_initial['jacobian']))
        result['jacobian'] = torch.matmul(relative_jac, kp_source['jacobian'])
    return result
def _resample_audio_to_16k(audio_path, work_dir):
    """Return a path to audio sampled at 16 kHz for feature extraction.

    If ``audio_path`` is already 16 kHz it is returned unchanged (channel
    count is not checked). Otherwise ffmpeg writes a mono 16 kHz PCM copy
    into ``work_dir`` and that path is returned.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails.
    """
    sample_rate, _ = wavfile.read(audio_path)
    if sample_rate == 16000:
        return audio_path
    resampled_path = os.path.join(work_dir, "temp.wav")
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through verbatim.
    subprocess.run(
        ["ffmpeg", "-y", "-i", audio_path, "-async", "1", "-ac", "1", "-vn",
         "-acodec", "pcm_s16le", "-ar", "16000", resampled_path],
        check=True,
    )
    return resampled_path


def _build_windowed_inputs(audio_seq, ph_seq, rot_seq, trans_seq, frames, num_w):
    """Assemble sliding-window model inputs for every output frame.

    Frame ``rid`` gets ``2 * num_w + 1`` slots covering frames
    ``rid - num_w`` .. ``rid + num_w``. A slot outside ``[0, frames)`` gets
    phoneme id 31, an all-zero audio chunk, and the pose of the nearest
    valid frame (the first or the last one).

    Args:
        audio_seq: 2-D audio feature array with four rows per video frame.
        ph_seq: per-frame phoneme ids.
        rot_seq, trans_seq: per-frame head rotation / translation, indexable
            by frame number and convertible with ``np.array``.
        frames: number of output frames (>= 1).
        num_w: half-width of the window.

    Returns:
        ``(audio_f, poses, ph_frames)`` tensors, each with a leading batch
        axis of size 1 followed by a frame axis of size ``frames``.
    """
    pad_ph = 31  # phoneme id used for window slots that fall outside the clip
    pad_audio = np.zeros((4, audio_seq.shape[1]), dtype=np.float32)

    # A frame's pose image is the same in every window it appears in, so draw
    # it once rather than once per slot (assumes draw_annotation_box depends
    # only on its arguments).
    pose_images = []
    for j in range(frames):
        canvas = np.zeros([256, 256])
        draw_annotation_box(canvas, np.array(rot_seq[j]), np.array(trans_seq[j]))
        pose_images.append(canvas)

    ph_frames, audio_frames, pose_frames = [], [], []
    for rid in range(frames):
        window = range(rid - num_w, rid + num_w + 1)
        ph_frames.append([ph_seq[i] if 0 <= i < frames else pad_ph for i in window])
        audio_frames.append([audio_seq[i * 4:i * 4 + 4] if 0 <= i < frames else pad_audio
                             for i in window])
        # Clamp the slot index so out-of-range slots reuse the first/last pose.
        pose_frames.append([pose_images[min(max(i, 0), frames - 1)] for i in window])

    audio_f = torch.from_numpy(np.array(audio_frames, dtype=np.float32)).unsqueeze(0)
    poses = torch.from_numpy(np.array(pose_frames, dtype=np.float32)).unsqueeze(0)
    ph_tensor = torch.from_numpy(np.array(ph_frames)).unsqueeze(0)
    return audio_f, poses, ph_tensor


def _load_models(model_config, opt, generator_ckpt):
    """Build the keypoint detector, generator and audio-to-keypoint model on
    the GPU, load their weights from ``generator_ckpt`` and set eval mode."""
    params = model_config['model_params']
    kp_detector = KPDetector(**params['kp_detector_params'],
                             **params['common_params']).cuda()
    generator = OcclusionAwareGenerator(**params['generator_params'],
                                        **params['common_params']).cuda()
    ph2kp = Audio2kpTransformer(opt).cuda()
    load_ckpt(generator_ckpt, kp_detector=kp_detector, generator=generator, ph2kp=ph2kp)
    for model in (ph2kp, generator, kp_detector):
        model.eval()
    return kp_detector, generator, ph2kp


def _render_frames(img, audio_f, poses, ph_frames, kp_detector, generator, ph2kp):
    """Run the models frame by frame and return a list of uint8 images.

    Driving keypoints are normalized relative to those of the first frame
    (see ``normalize_kp``). The prediction is moved from channel-first to
    channel-last and scaled by 255; this assumes it is (1, C, H, W) with
    values in [0, 1]. TODO confirm.
    """
    predictions = []
    with torch.no_grad():
        # The source image never changes, so detect its keypoints once
        # (assumes ph2kp does not modify kp_source in place).
        kp_source = kp_detector(img, True)
        drive_first = None
        for frame_idx in range(audio_f.shape[1]):
            inputs = {
                "audio": audio_f[:, frame_idx].cuda(),
                "pose": poses[:, frame_idx].cuda(),
                "ph": ph_frames[:, frame_idx].cuda(),
                "id_img": img,
            }
            gen_kp = ph2kp(inputs, kp_source)
            if drive_first is None:
                drive_first = gen_kp
            norm = normalize_kp(kp_source=kp_source, kp_driving=gen_kp,
                                kp_driving_initial=drive_first)
            out_gen = generator(img, kp_source=kp_source, kp_driving=norm)
            frame = np.transpose(out_gen['prediction'].cpu().numpy(), [0, 2, 3, 1])[0]
            predictions.append((frame * 255).astype(np.uint8))
    return predictions


def _write_video(frames, img_path, audio_path, save_dir, fps=25.0):
    """Encode ``frames`` and mux them with ``audio_path`` into
    ``<save_dir>/<image stem>_<audio stem>.mp4``; return that path.

    The silent intermediate video is removed only after ffmpeg succeeds.

    Raises:
        subprocess.CalledProcessError: if the ffmpeg mux fails.
    """
    temp_dir = os.path.join(save_dir, "temp")
    os.makedirs(temp_dir, exist_ok=True)

    # splitext handles any extension length (".jpeg", ".wave"), unlike [:-4].
    img_stem = os.path.splitext(os.path.basename(img_path))[0]
    audio_stem = os.path.splitext(os.path.basename(audio_path))[0]
    f_name = img_stem + "_" + audio_stem + ".mp4"

    video_path = os.path.join(temp_dir, f_name)
    print("save video to: ", video_path)
    imageio.mimsave(video_path, frames, fps=fps)

    save_video = os.path.join(save_dir, f_name)
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_path, "-i", audio_path, "-vcodec", "copy", save_video],
        check=True,
    )
    os.remove(video_path)
    return save_video


def test_with_input_audio_and_image(img_path, audio_path, phs, generator_ckpt, audio2pose_ckpt,
                                    save_dir="samples/results"):
    """Render a talking-face video of ``img_path`` driven by ``audio_path``.

    Args:
        img_path: input face image (preprocessed by image_preprocess.py, per
            the CLI help).
        audio_path: input ``.wav``. It is resampled to 16 kHz for feature
            extraction if needed; the original file is muxed into the output.
        phs: parsed phoneme data; only ``phs["phone_list"]`` is read.
        generator_ckpt: checkpoint given to ``load_ckpt`` for the keypoint
            detector, generator and audio-to-keypoint transformer.
        audio2pose_ckpt: checkpoint given to ``get_pose_from_audio``.
        save_dir: output directory. A ``temp`` subdirectory is used for the
            intermediate silent video.

    Returns:
        Path of the final ``.mp4`` (the original returned None).

    Raises:
        ValueError: if the audio / phoneme input yields no frames.
        subprocess.CalledProcessError: if an ffmpeg call fails.
    """
    # Named model_config so it does not shadow the module-level ``config`` import.
    with open("config_file/vox-256.yaml") as f:
        model_config = yaml.full_load(f)
    with open("config_file/audio2kp.yaml") as f:
        opt = argparse.Namespace(**yaml.full_load(f))

    img = read_img(img_path).cuda()
    first_pose = get_img_pose(img_path)

    # Any resampled copy lives only as long as feature extraction needs it.
    with tempfile.TemporaryDirectory() as work_dir:
        audio_feature = get_audio_feature_from_audio(_resample_audio_to_16k(audio_path, work_dir))

    # Four audio-feature rows per video frame; never run past the phoneme track.
    ph_seq = phs["phone_list"]
    frames = min(len(audio_feature) // 4, len(ph_seq))
    if frames == 0:
        raise ValueError(
            f"no frames to render: {len(audio_feature)} audio feature rows, "
            f"{len(ph_seq)} phonemes")

    # Predict a head-pose sequence from the audio, seeded with the image's pose
    # (first three values rotation, rest translation).
    tp = np.zeros([256, 256], dtype=np.float32)
    draw_annotation_box(tp, first_pose[:3], first_pose[3:])
    tp = torch.from_numpy(tp).unsqueeze(0).unsqueeze(0).cuda()
    ref_pose = get_pose_from_audio(tp, audio_feature, audio2pose_ckpt)
    torch.cuda.empty_cache()

    audio_f, poses, ph_frames = _build_windowed_inputs(
        audio_feature, ph_seq, ref_pose[:, :3], ref_pose[:, 3:], frames, opt.num_w)
    kp_detector, generator, ph2kp = _load_models(model_config, opt, generator_ckpt)
    predictions = _render_frames(img, audio_f, poses, ph_frames, kp_detector, generator, ph2kp)
    return _write_video(predictions, img_path, audio_path, save_dir)
# Command-line entry point: render one talking-face video from an image, a
# .wav file and its matching phoneme file, using the checkpoint paths defined
# in the local ``config`` module.
if __name__ == '__main__':
    argparser = argparse.ArgumentParser()
    # The three inputs have no usable default. Requiring them makes argparse
    # print a usage error instead of crashing later on a None path.
    argparser.add_argument("--img_path", type=str, required=True,
                           help="path of the input image ( .jpg ), preprocessed by image_preprocess.py")
    argparser.add_argument("--audio_path", type=str, required=True,
                           help="path of the input audio ( .wav )")
    argparser.add_argument("--phoneme_path", type=str, required=True,
                           help="path of the input phoneme file; the phonemes must be consistent with the input audio")
    argparser.add_argument("--save_dir", type=str, default="samples/results",
                           help="directory for the output video")
    args = argparser.parse_args()

    phoneme = parse_phoneme_file(args.phoneme_path)
    test_with_input_audio_and_image(args.img_path, args.audio_path, phoneme,
                                    config.GENERATOR_CKPT, config.AUDIO2POSE_CKPT,
                                    args.save_dir)