DmitrMakeev
commited on
Commit
•
c626b55
1
Parent(s):
183ec7c
Upload 7 files
Browse files- .gitattributes +59 -0
- README.md +59 -0
- config.py +3 -0
- image_preprocess.py +57 -0
- phindex.json +1 -0
- requirements.txt +8 -0
- test_script.py +180 -0
.gitattributes
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
OpenFace/FaceLandmarkVidMulti filter=lfs diff=lfs merge=lfs -text
|
36 |
+
OpenFace/FeatureExtraction filter=lfs diff=lfs merge=lfs -text
|
37 |
+
OpenFace/FaceLandmarkVid filter=lfs diff=lfs merge=lfs -text
|
38 |
+
OpenFace/FaceLandmarkImg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
OpenFace/model/detection_validation/validator_cnn.txt filter=lfs diff=lfs merge=lfs -text
|
40 |
+
OpenFace/model/detection_validation/validator_cnn_68.txt filter=lfs diff=lfs merge=lfs -text
|
41 |
+
OpenFace/model/model_inner/patch_experts/ccnf_patches_1.00_inner.txt filter=lfs diff=lfs merge=lfs -text
|
42 |
+
OpenFace/model/patch_experts/ccnf_patches_0.5_wild.txt filter=lfs diff=lfs merge=lfs -text
|
43 |
+
OpenFace/model/patch_experts/ccnf_patches_1_wild.txt filter=lfs diff=lfs merge=lfs -text
|
44 |
+
OpenFace/model/patch_experts/ccnf_patches_0.35_multi_pie.txt filter=lfs diff=lfs merge=lfs -text
|
45 |
+
OpenFace/model/patch_experts/ccnf_patches_0.35_wild.txt filter=lfs diff=lfs merge=lfs -text
|
46 |
+
OpenFace/model/patch_experts/ccnf_patches_0.25_wild.txt filter=lfs diff=lfs merge=lfs -text
|
47 |
+
OpenFace/model/patch_experts/ccnf_patches_0.5_multi_pie.txt filter=lfs diff=lfs merge=lfs -text
|
48 |
+
OpenFace/model/patch_experts/ccnf_patches_0.25_multi_pie.txt filter=lfs diff=lfs merge=lfs -text
|
49 |
+
OpenFace/model/patch_experts/ccnf_patches_0.5_general.txt filter=lfs diff=lfs merge=lfs -text
|
50 |
+
OpenFace/model/patch_experts/ccnf_patches_0.25_general.txt filter=lfs diff=lfs merge=lfs -text
|
51 |
+
OpenFace/model/patch_experts/ccnf_patches_0.35_general.txt filter=lfs diff=lfs merge=lfs -text
|
52 |
+
OpenFace/model/mtcnn_detector/ONet.dat filter=lfs diff=lfs merge=lfs -text
|
53 |
+
samples/audios/trump.wav filter=lfs diff=lfs merge=lfs -text
|
54 |
+
samples/audios/abstract.wav filter=lfs diff=lfs merge=lfs -text
|
55 |
+
samples/audios/obama2.wav filter=lfs diff=lfs merge=lfs -text
|
56 |
+
OpenFace/model/patch_experts/cen_patches_0.35_of.dat filter=lfs diff=lfs merge=lfs -text
|
57 |
+
OpenFace/model/patch_experts/cen_patches_0.25_of.dat filter=lfs diff=lfs merge=lfs -text
|
58 |
+
OpenFace/model/patch_experts/cen_patches_1.00_of.dat filter=lfs diff=lfs merge=lfs -text
|
59 |
+
OpenFace/model/patch_experts/cen_patches_0.50_of.dat filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# One-shot Talking Face Generation from Single-speaker Audio-Visual Correlation Learning (AAAI 2022)
|
2 |
+
|
3 |
+
#### [Paper](https://arxiv.org/pdf/2112.02749.pdf) | [Demo](https://www.youtube.com/watch?v=HHj-XCXXePY)
|
4 |
+
|
5 |
+
#### Requirements
|
6 |
+
|
7 |
+
- Python >= 3.6 , Pytorch >= 1.8 and ffmpeg
|
8 |
+
- Set up [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace)
|
9 |
+
- We use the OpenFace tools to extract the initial pose of the reference image
|
10 |
+
- Make sure you have installed this tool, and set the `OPENFACE_POSE_EXTRACTOR_PATH` in `config.py`. For example, it should be the absolute path of the "`FeatureExtraction.exe`" for Windows.
|
11 |
+
- Other requirements are listed in the 'requirements.txt'
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
#### Pretrained Checkpoint
|
16 |
+
|
17 |
+
Please download the pretrained checkpoint from [google-drive](https://drive.google.com/file/d/1mjFEozPR_2vMaVRMd9Agk_sU1VaiUYMl/view?usp=sharing) and unzip it to the directory (`/checkpoints`). Or manually modify the settings of `GENERATOR_CKPT` and `AUDIO2POSE_CKPT` in the `config.py`.
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
#### Extract phoneme
|
22 |
+
|
23 |
+
We employ the [CMU phoneset](https://github.com/cmusphinx/cmudict) to represent phonemes, the extra 'SIL' means silence. All the phonesets can be seen in '`phindex.json`'.
|
24 |
+
|
25 |
+
We have extracted the phonemes for the audios in the '`sample/audio`' directory. For other audios, you can extract the phonemes by other ASR tools and then map them to the CMU phoneset. Or email to wangsuzhen@corp.netease.com for help.
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
#### Generate Demo Results
|
30 |
+
|
31 |
+
```
|
32 |
+
python test_script.py --img_path xxx.jpg --audio_path xxx.wav --phoneme_path xxx.json --save_dir "YOUR_DIR"
|
33 |
+
```
|
34 |
+
|
35 |
+
Note that the input images must keep the same height and width and the face should be appropriately cropped as in `samples/imgs`. You can also preprocess your images with `image_preprocess.py`.
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
#### License and Citation
|
40 |
+
|
41 |
+
```
|
42 |
+
@InProceedings{wang2021one,
|
43 |
+
author = Suzhen Wang, Lincheng Li, Yu Ding, Xin Yu
|
44 |
+
title = {One-shot Talking Face Generation from Single-speaker Audio-Visual Correlation Learning},
|
45 |
+
booktitle = {AAAI 2022},
|
46 |
+
year = {2022},
|
47 |
+
}
|
48 |
+
```
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
#### Acknowledgement
|
53 |
+
|
54 |
+
This codebase is based on [First Order Motion Model](https://github.com/AliaksandrSiarohin/first-order-model) and [imaginaire](https://github.com/NVlabs/imaginaire), thanks for their contributions.
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
config.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
OPENFACE_POSE_EXTRACTOR_PATH = "/content/one-shot-talking-face/OpenFace/FeatureExtraction"
|
2 |
+
GENERATOR_CKPT = "/content/one-shot-talking-face/checkpoints/generator.ckpt"
|
3 |
+
AUDIO2POSE_CKPT = "/content/one-shot-talking-face/checkpoints/audio2pose.ckpt"
|
image_preprocess.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dlib
|
2 |
+
import cv2
|
3 |
+
def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
|
4 |
+
left, top, right, bot = bbox
|
5 |
+
width = right - left
|
6 |
+
height = bot - top
|
7 |
+
|
8 |
+
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
|
9 |
+
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
|
10 |
+
|
11 |
+
left_t = int(left - width_increase * width)
|
12 |
+
top_t = int(top - height_increase * height)
|
13 |
+
right_t = int(right + width_increase * width)
|
14 |
+
bot_t = int(bot + height_increase * height)
|
15 |
+
|
16 |
+
left_oob = -min(0, left_t)
|
17 |
+
right_oob = right - min(right_t, w)
|
18 |
+
top_oob = -min(0, top_t)
|
19 |
+
bot_oob = bot - min(bot_t, h)
|
20 |
+
|
21 |
+
if max(left_oob, right_oob, top_oob, bot_oob) > 0:
|
22 |
+
max_w = max(left_oob, right_oob)
|
23 |
+
max_h = max(top_oob, bot_oob)
|
24 |
+
if max_w > max_h:
|
25 |
+
return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
|
26 |
+
else:
|
27 |
+
return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
|
28 |
+
|
29 |
+
else:
|
30 |
+
return (left_t, top_t, right_t, bot_t)
|
31 |
+
|
32 |
+
def crop_src_image(src_img,save_img, detector=None):
|
33 |
+
if detector is None:
|
34 |
+
detector = dlib.get_frontal_face_detector()
|
35 |
+
|
36 |
+
img = cv2.imread(src_img)
|
37 |
+
faces = detector(img, 0)
|
38 |
+
h, width, _ = img.shape
|
39 |
+
if len(faces) > 0:
|
40 |
+
bbox = [faces[0].left(), faces[0].top(),faces[0].right(), faces[0].bottom()]
|
41 |
+
l = bbox[3]-bbox[1]
|
42 |
+
bbox[1]= bbox[1]-l*0.1
|
43 |
+
bbox[3]= bbox[3]-l*0.1
|
44 |
+
bbox[1] = max(0,bbox[1])
|
45 |
+
bbox[3] = min(h,bbox[3])
|
46 |
+
bbox = compute_aspect_preserved_bbox(tuple(bbox), 0.5, img.shape[0], img.shape[1])
|
47 |
+
img = img[bbox[1] :bbox[3] , bbox[0]:bbox[2]]
|
48 |
+
img = cv2.resize(img, (256, 256))
|
49 |
+
cv2.imwrite(save_img,img)
|
50 |
+
else:
|
51 |
+
img = cv2.resize(img,(256,256))
|
52 |
+
cv2.imwrite(save_img, img)
|
53 |
+
|
54 |
+
if __name__ == '__main__':
|
55 |
+
src_img = ""
|
56 |
+
out_img = ""
|
57 |
+
crop_src_image(src_img,out_img)
|
phindex.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"AA": 0, "AE": 1, "AH": 2, "AO": 3, "AW": 4, "AY": 5, "B": 6, "CH": 7, "D": 8, "DH": 9, "EH": 10, "ER": 11, "EY": 12, "F": 13, "G": 14, "HH": 15, "IH": 16, "IY": 17, "JH": 18, "K": 19, "L": 20, "M": 21, "N": 22, "NG": 23, "NSN": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "SIL": 31, "T": 32, "TH": 33, "UH": 34, "UW": 35, "V": 36, "W": 37, "Y": 38, "Z": 39, "ZH": 40}
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
scikit-image
|
2 |
+
python_speech_features
|
3 |
+
pyworld
|
4 |
+
pyyaml
|
5 |
+
imageio
|
6 |
+
scipy
|
7 |
+
pyworld
|
8 |
+
opencv-python
|
test_script.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import yaml
|
5 |
+
from models.generator import OcclusionAwareGenerator
|
6 |
+
from models.keypoint_detector import KPDetector
|
7 |
+
import argparse
|
8 |
+
import imageio
|
9 |
+
from models.util import draw_annotation_box
|
10 |
+
from models.transformer import Audio2kpTransformer
|
11 |
+
from scipy.io import wavfile
|
12 |
+
from tools.interface import read_img,get_img_pose,get_pose_from_audio,get_audio_feature_from_audio,\
|
13 |
+
parse_phoneme_file,load_ckpt
|
14 |
+
import config
|
15 |
+
|
16 |
+
def normalize_kp(kp_source, kp_driving, kp_driving_initial,
|
17 |
+
use_relative_movement=True, use_relative_jacobian=True):
|
18 |
+
|
19 |
+
kp_new = {k: v for k, v in kp_driving.items()}
|
20 |
+
if use_relative_movement:
|
21 |
+
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
|
22 |
+
# kp_value_diff *= adapt_movement_scale
|
23 |
+
kp_new['value'] = kp_value_diff + kp_source['value']
|
24 |
+
|
25 |
+
if use_relative_jacobian:
|
26 |
+
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
|
27 |
+
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
|
28 |
+
|
29 |
+
return kp_new
|
30 |
+
|
31 |
+
|
32 |
+
def test_with_input_audio_and_image(img_path, audio_path,phs, generator_ckpt, audio2pose_ckpt, save_dir="samples/results"):
|
33 |
+
with open("config_file/vox-256.yaml") as f:
|
34 |
+
config = yaml.full_load(f)
|
35 |
+
# temp_audio = audio_path
|
36 |
+
# print(audio_path)
|
37 |
+
cur_path = os.getcwd()
|
38 |
+
|
39 |
+
sr,_ = wavfile.read(audio_path)
|
40 |
+
if sr!=16000:
|
41 |
+
temp_audio = os.path.join(cur_path,"samples","temp.wav")
|
42 |
+
command = "ffmpeg -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (audio_path, temp_audio)
|
43 |
+
os.system(command)
|
44 |
+
else:
|
45 |
+
temp_audio = audio_path
|
46 |
+
|
47 |
+
|
48 |
+
opt = argparse.Namespace(**yaml.full_load(open("config_file/audio2kp.yaml")))
|
49 |
+
|
50 |
+
img = read_img(img_path).cuda()
|
51 |
+
|
52 |
+
first_pose = get_img_pose(img_path)#.cuda()
|
53 |
+
|
54 |
+
audio_feature = get_audio_feature_from_audio(temp_audio)
|
55 |
+
frames = len(audio_feature) // 4
|
56 |
+
frames = min(frames,len(phs["phone_list"]))
|
57 |
+
|
58 |
+
tp = np.zeros([256, 256], dtype=np.float32)
|
59 |
+
draw_annotation_box(tp, first_pose[:3], first_pose[3:])
|
60 |
+
tp = torch.from_numpy(tp).unsqueeze(0).unsqueeze(0).cuda()
|
61 |
+
ref_pose = get_pose_from_audio(tp, audio_feature, audio2pose_ckpt)
|
62 |
+
torch.cuda.empty_cache()
|
63 |
+
trans_seq = ref_pose[:, 3:]
|
64 |
+
rot_seq = ref_pose[:, :3]
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
audio_seq = audio_feature#[40:]
|
69 |
+
ph_seq = phs["phone_list"]
|
70 |
+
|
71 |
+
|
72 |
+
ph_frames = []
|
73 |
+
audio_frames = []
|
74 |
+
pose_frames = []
|
75 |
+
name_len = frames
|
76 |
+
|
77 |
+
pad = np.zeros((4, audio_seq.shape[1]), dtype=np.float32)
|
78 |
+
|
79 |
+
for rid in range(0, frames):
|
80 |
+
ph = []
|
81 |
+
audio = []
|
82 |
+
pose = []
|
83 |
+
for i in range(rid - opt.num_w, rid + opt.num_w + 1):
|
84 |
+
if i < 0:
|
85 |
+
rot = rot_seq[0]
|
86 |
+
trans = trans_seq[0]
|
87 |
+
ph.append(31)
|
88 |
+
audio.append(pad)
|
89 |
+
elif i >= name_len:
|
90 |
+
ph.append(31)
|
91 |
+
rot = rot_seq[name_len - 1]
|
92 |
+
trans = trans_seq[name_len - 1]
|
93 |
+
audio.append(pad)
|
94 |
+
else:
|
95 |
+
ph.append(ph_seq[i])
|
96 |
+
rot = rot_seq[i]
|
97 |
+
trans = trans_seq[i]
|
98 |
+
audio.append(audio_seq[i * 4:i * 4 + 4])
|
99 |
+
tmp_pose = np.zeros([256, 256])
|
100 |
+
draw_annotation_box(tmp_pose, np.array(rot), np.array(trans))
|
101 |
+
pose.append(tmp_pose)
|
102 |
+
|
103 |
+
ph_frames.append(ph)
|
104 |
+
audio_frames.append(audio)
|
105 |
+
pose_frames.append(pose)
|
106 |
+
|
107 |
+
audio_f = torch.from_numpy(np.array(audio_frames,dtype=np.float32)).unsqueeze(0)
|
108 |
+
poses = torch.from_numpy(np.array(pose_frames, dtype=np.float32)).unsqueeze(0)
|
109 |
+
ph_frames = torch.from_numpy(np.array(ph_frames)).unsqueeze(0)
|
110 |
+
bs = audio_f.shape[1]
|
111 |
+
predictions_gen = []
|
112 |
+
|
113 |
+
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
|
114 |
+
**config['model_params']['common_params'])
|
115 |
+
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
|
116 |
+
**config['model_params']['common_params'])
|
117 |
+
kp_detector = kp_detector.cuda()
|
118 |
+
generator = generator.cuda()
|
119 |
+
|
120 |
+
ph2kp = Audio2kpTransformer(opt).cuda()
|
121 |
+
|
122 |
+
load_ckpt(generator_ckpt, kp_detector=kp_detector, generator=generator,ph2kp=ph2kp)
|
123 |
+
|
124 |
+
|
125 |
+
ph2kp.eval()
|
126 |
+
generator.eval()
|
127 |
+
kp_detector.eval()
|
128 |
+
|
129 |
+
with torch.no_grad():
|
130 |
+
for frame_idx in range(bs):
|
131 |
+
t = {}
|
132 |
+
|
133 |
+
t["audio"] = audio_f[:, frame_idx].cuda()
|
134 |
+
t["pose"] = poses[:, frame_idx].cuda()
|
135 |
+
t["ph"] = ph_frames[:,frame_idx].cuda()
|
136 |
+
t["id_img"] = img
|
137 |
+
|
138 |
+
kp_gen_source = kp_detector(img, True)
|
139 |
+
|
140 |
+
gen_kp = ph2kp(t,kp_gen_source)
|
141 |
+
if frame_idx == 0:
|
142 |
+
drive_first = gen_kp
|
143 |
+
|
144 |
+
norm = normalize_kp(kp_source=kp_gen_source, kp_driving=gen_kp, kp_driving_initial=drive_first)
|
145 |
+
out_gen = generator(img, kp_source=kp_gen_source, kp_driving=norm)
|
146 |
+
|
147 |
+
predictions_gen.append(
|
148 |
+
(np.transpose(out_gen['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8))
|
149 |
+
|
150 |
+
|
151 |
+
log_dir = save_dir
|
152 |
+
os.makedirs(os.path.join(log_dir, "temp"),exist_ok=True)
|
153 |
+
|
154 |
+
f_name = os.path.basename(img_path)[:-4] + "_" + os.path.basename(audio_path)[:-4] + ".mp4"
|
155 |
+
# kwargs = {'duration': 1. / 25.0}
|
156 |
+
video_path = os.path.join(log_dir, "temp", f_name)
|
157 |
+
print("save video to: ", video_path)
|
158 |
+
imageio.mimsave(video_path, predictions_gen, fps=25.0)
|
159 |
+
|
160 |
+
# audio_path = os.path.join(audio_dir, x['name'][0].replace(".mp4", ".wav"))
|
161 |
+
save_video = os.path.join(log_dir, f_name)
|
162 |
+
cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video_path, audio_path, save_video)
|
163 |
+
os.system(cmd)
|
164 |
+
os.remove(video_path)
|
165 |
+
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
if __name__ == '__main__':
|
172 |
+
argparser = argparse.ArgumentParser()
|
173 |
+
argparser.add_argument("--img_path", type=str, default=None, help="path of the input image ( .jpg ), preprocessed by image_preprocess.py")
|
174 |
+
argparser.add_argument("--audio_path", type=str, default=None, help="path of the input audio ( .wav )")
|
175 |
+
argparser.add_argument("--phoneme_path", type=str, default=None, help="path of the input phoneme. It should be note that the phoneme must be consistent with the input audio")
|
176 |
+
argparser.add_argument("--save_dir", type=str, default="samples/results", help="path of the output video")
|
177 |
+
args = argparser.parse_args()
|
178 |
+
|
179 |
+
phoneme = parse_phoneme_file(args.phoneme_path)
|
180 |
+
test_with_input_audio_and_image(args.img_path,args.audio_path,phoneme,config.GENERATOR_CKPT,config.AUDIO2POSE_CKPT,args.save_dir)
|