Spanicin committed
Commit d8431dd
1 Parent(s): 5b1ae50

Upload 6 files

videoretalking/utils/alignment_stit.py ADDED
@@ -0,0 +1,211 @@
import PIL
import PIL.Image
import dlib
import face_alignment
import numpy as np
import scipy
import scipy.ndimage
import skimage.io as io
import torch
from PIL import Image
from scipy.ndimage import gaussian_filter1d
from tqdm import tqdm

# from configs import paths_config
def paste_image(inverse_transform, img, orig_image):
    pasted_image = orig_image.copy().convert('RGBA')
    projected = img.convert('RGBA').transform(orig_image.size, Image.PERSPECTIVE, inverse_transform, Image.BILINEAR)
    pasted_image.paste(projected, (0, 0), mask=projected)
    return pasted_image

def get_landmark(filepath, predictor, detector=None, fa=None):
    """get landmark with dlib
    :return: np.array shape=(68, 2)
    """
    if fa is not None:
        image = io.imread(filepath)
        lms, _, bboxes = fa.get_landmarks(image, return_bboxes=True)
        if lms is None or len(lms) == 0:
            return None
        return lms[0]

    if detector is None:
        detector = dlib.get_frontal_face_detector()
    if isinstance(filepath, PIL.Image.Image):
        img = np.array(filepath)
    else:
        img = dlib.load_rgb_image(filepath)
    dets = detector(img)

    for k, d in enumerate(dets):
        shape = predictor(img, d)
        break
    else:
        return None
    t = list(shape.parts())
    a = []
    for tt in t:
        a.append([tt.x, tt.y])
    lm = np.array(a)
    return lm


def align_face(filepath_or_image, predictor, output_size, detector=None,
               enable_padding=False, scale=1.0):
    """
    :param filepath_or_image: str or PIL.Image.Image
    :return: PIL Image
    """

    c, x, y = compute_transform(filepath_or_image, predictor, detector=detector,
                                scale=scale)
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    img = crop_image(filepath_or_image, output_size, quad, enable_padding=enable_padding)

    # Return aligned image.
    return img


def crop_image(filepath, output_size, quad, enable_padding=False):
    x = (quad[3] - quad[1]) / 2
    qsize = np.hypot(*x) * 2
    # read image
    if isinstance(filepath, PIL.Image.Image):
        img = filepath
    else:
        img = PIL.Image.open(filepath)
    transform_size = output_size
    # Shrink.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
        img = img.resize(rsize, PIL.Image.ANTIALIAS)
        quad /= shrink
        qsize /= shrink
    # Crop.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
            int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
            min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]
    # Pad.
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
           int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
           max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
                          1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]
    # Transform.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
    if output_size < transform_size:
        img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
    return img

def compute_transform(lm, predictor, detector=None, scale=1.0, fa=None):
    # lm = get_landmark(filepath, predictor, detector, fa)
    # if lm is None:
    #     raise Exception(f'Did not detect any faces in image: {filepath}')
    lm_chin = lm[0: 17]  # left-right
    lm_eyebrow_left = lm[17: 22]  # left-right
    lm_eyebrow_right = lm[22: 27]  # left-right
    lm_nose = lm[27: 31]  # top-down
    lm_nostrils = lm[31: 36]  # top-down
    lm_eye_left = lm[36: 42]  # left-clockwise
    lm_eye_right = lm[42: 48]  # left-clockwise
    lm_mouth_outer = lm[48: 60]  # left-clockwise
    lm_mouth_inner = lm[60: 68]  # left-clockwise
    # Calculate auxiliary vectors.
    eye_left = np.mean(lm_eye_left, axis=0)
    eye_right = np.mean(lm_eye_right, axis=0)
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    mouth_left = lm_mouth_outer[0]
    mouth_right = lm_mouth_outer[6]
    mouth_avg = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg
    # Choose oriented crop rectangle.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)

    x *= scale
    y = np.flipud(x) * [-1, 1]
    c = eye_avg + eye_to_mouth * 0.1
    return c, x, y


def crop_faces(IMAGE_SIZE, files, scale, center_sigma=0.0, xy_sigma=0.0, use_fa=False, fa=None):
    if use_fa:
        if fa is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True, device=device)
        predictor = None
        detector = None
    else:
        fa = None
        predictor = None
        detector = None
        # predictor = dlib.shape_predictor(paths_config.shape_predictor_path)
        # detector = dlib.get_frontal_face_detector()

    cs, xs, ys = [], [], []
    for lm, pil in tqdm(files):
        c, x, y = compute_transform(lm, predictor, detector=detector,
                                    scale=scale, fa=fa)
        cs.append(c)
        xs.append(x)
        ys.append(y)

    cs = np.stack(cs)
    xs = np.stack(xs)
    ys = np.stack(ys)
    if center_sigma != 0:
        cs = gaussian_filter1d(cs, sigma=center_sigma, axis=0)

    if xy_sigma != 0:
        xs = gaussian_filter1d(xs, sigma=xy_sigma, axis=0)
        ys = gaussian_filter1d(ys, sigma=xy_sigma, axis=0)

    quads = np.stack([cs - xs - ys, cs - xs + ys, cs + xs + ys, cs + xs - ys], axis=1)
    quads = list(quads)

    crops, orig_images = crop_faces_by_quads(IMAGE_SIZE, files, quads)

    return crops, orig_images, quads


def crop_faces_by_quads(IMAGE_SIZE, files, quads):
    orig_images = []
    crops = []
    for quad, (_, path) in tqdm(zip(quads, files), total=len(quads)):
        crop = crop_image(path, IMAGE_SIZE, quad.copy())
        orig_image = path  # Image.open(path)
        orig_images.append(orig_image)
        crops.append(crop)
    return crops, orig_images


def calc_alignment_coefficients(pa, pb):
    matrix = []
    for p1, p2 in zip(pa, pb):
        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    a = np.matrix(matrix, dtype=float)
    b = np.array(pb).reshape(8)

    res = np.dot(np.linalg.inv(a.T * a) * a.T, b)
    return np.array(res).reshape(8)
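
A rough usage sketch of the helpers above, in the STIT crop-then-stitch style (the frame path, crop size, and smoothing sigmas are illustrative assumptions; imports assume the videoretalking folder is the working directory):

import numpy as np
import torch
import face_alignment
from PIL import Image
from utils.alignment_stit import crop_faces, calc_alignment_coefficients, paste_image

size = 512                                # assumed crop resolution
frames = [Image.open('frames/0001.png')]  # hypothetical frames, each assumed to contain one face
device = 'cuda' if torch.cuda.is_available() else 'cpu'
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True, device=device)
files = [(fa.get_landmarks(np.array(im))[0], im) for im in frames]  # (68x2 landmarks, PIL image) pairs

# Temporally smooth the crop rectangles, then cut the aligned face crops.
crops, orig_images, quads = crop_faces(size, files, scale=1.0,
                                       center_sigma=1.0, xy_sigma=3.0, use_fa=True, fa=fa)

# Invert each crop quad and paste the (possibly edited) crop back onto its frame.
for crop, orig, quad in zip(crops, orig_images, quads):
    coeffs = calc_alignment_coefficients(quad + 0.5, [[0, 0], [0, size], [size, size], [size, 0]])
    stitched = paste_image(coeffs, crop, orig)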
videoretalking/utils/audio.py ADDED
@@ -0,0 +1,136 @@
import librosa
import librosa.filters
import numpy as np
# import tensorflow as tf
from scipy import signal
from scipy.io import wavfile
from .hparams import hparams as hp

def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]

def save_wav(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    # proposed by @dsmiller
    wavfile.write(path, sr, wav.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
    librosa.output.write_wav(path, wav, sr=sr)

def preemphasis(wav, k, preemphasize=True):
    if preemphasize:
        return signal.lfilter([1, -k], [1], wav)
    return wav

def inv_preemphasis(wav, k, inv_preemphasize=True):
    if inv_preemphasize:
        return signal.lfilter([1], [1, -k], wav)
    return wav

def get_hop_size():
    hop_size = hp.hop_size
    if hop_size is None:
        assert hp.frame_shift_ms is not None
        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
    return hop_size

def linearspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(np.abs(D)) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def melspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def _lws_processor():
    import lws
    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")

def _stft(y):
    if hp.use_lws:
        return _lws_processor().stft(y).T
    else:
        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)

##########################################################
# These are only correct when using lws!!! (This was messing with WaveNet quality for a long time!)
def num_frames(length, fsize, fshift):
    """Compute the number of time frames of a spectrogram
    """
    pad = (fsize - fshift)
    if length % fshift == 0:
        M = (length + pad * 2 - fsize) // fshift + 1
    else:
        M = (length + pad * 2 - fsize) // fshift + 2
    return M


def pad_lr(x, fsize, fshift):
    """Compute left and right padding
    """
    M = num_frames(len(x), fsize, fshift)
    pad = (fsize - fshift)
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r
##########################################################
# Librosa-correct padding
def librosa_pad_lr(x, fsize, fshift):
    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]

# Conversions
_mel_basis = None

def _linear_to_mel(spectogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectogram)

def _build_mel_basis():
    assert hp.fmax <= hp.sample_rate // 2
    return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels,
                               fmin=hp.fmin, fmax=hp.fmax)

def _amp_to_db(x):
    min_level = np.exp(hp.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
                           -hp.max_abs_value, hp.max_abs_value)
        else:
            return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)

    assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
    if hp.symmetric_mels:
        return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
    else:
        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))

def _denormalize(D):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return (((np.clip(D, -hp.max_abs_value,
                              hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
                    + hp.min_level_db)
        else:
            return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)

    if hp.symmetric_mels:
        return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
    else:
        return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
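
A minimal sketch of the typical call sequence (the wav path is a placeholder; assumes a 16 kHz mono file and that the videoretalking folder is the working directory):

from utils import audio
from utils.hparams import hparams as hp

wav = audio.load_wav('speech.wav', hp.sample_rate)  # float32 waveform in [-1, 1]
mel = audio.melspectrogram(wav)                     # np.ndarray of shape (hp.num_mels, T), i.e. (80, T)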
videoretalking/utils/ffhq_preprocess.py ADDED
@@ -0,0 +1,140 @@
import os
import cv2
import time
import glob
import argparse
import scipy
import numpy as np
from PIL import Image
from tqdm import tqdm
from itertools import cycle
from torch.multiprocessing import Pool, Process, set_start_method


"""
brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
author: lzhbrian (https://lzhbrian.me)
date: 2020.1.5
note: code is heavily borrowed from
    https://github.com/NVlabs/ffhq-dataset
    http://dlib.net/face_landmark_detection.py.html
requirements:
    apt install cmake
    conda install Pillow numpy scipy
    pip install dlib
    # download face landmark model from:
    # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
"""

import numpy as np
from PIL import Image
import dlib


class Croper:
    def __init__(self, path_of_lm):
        # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
        self.predictor = dlib.shape_predictor(path_of_lm)

    def get_landmark(self, img_np):
        """get landmark with dlib
        :return: np.array shape=(68, 2)
        """
        detector = dlib.get_frontal_face_detector()
        dets = detector(img_np, 1)
        if len(dets) == 0:
            return None
        d = dets[0]
        # Get the landmarks/parts for the face in box d.
        shape = self.predictor(img_np, d)
        t = list(shape.parts())
        a = []
        for tt in t:
            a.append([tt.x, tt.y])
        lm = np.array(a)
        return lm

    def align_face(self, img, lm, output_size=1024):
        """
        :param img: PIL Image
        :param lm: np.array shape=(68, 2)
        :return: crop box and quad box
        """
        lm_chin = lm[0: 17]  # left-right
        lm_eyebrow_left = lm[17: 22]  # left-right
        lm_eyebrow_right = lm[22: 27]  # left-right
        lm_nose = lm[27: 31]  # top-down
        lm_nostrils = lm[31: 36]  # top-down
        lm_eye_left = lm[36: 42]  # left-clockwise
        lm_eye_right = lm[42: 48]  # left-clockwise
        lm_mouth_outer = lm[48: 60]  # left-clockwise
        lm_mouth_inner = lm[60: 68]  # left-clockwise

        # Calculate auxiliary vectors.
        eye_left = np.mean(lm_eye_left, axis=0)
        eye_right = np.mean(lm_eye_right, axis=0)
        eye_avg = (eye_left + eye_right) * 0.5
        eye_to_eye = eye_right - eye_left
        mouth_left = lm_mouth_outer[0]
        mouth_right = lm_mouth_outer[6]
        mouth_avg = (mouth_left + mouth_right) * 0.5
        eye_to_mouth = mouth_avg - eye_avg

        # Choose oriented crop rectangle.
        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
        y = np.flipud(x) * [-1, 1]
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
        qsize = np.hypot(*x) * 2

        # Shrink.
        shrink = int(np.floor(qsize / output_size * 0.5))
        if shrink > 1:
            rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
            img = img.resize(rsize, Image.ANTIALIAS)
            quad /= shrink
            qsize /= shrink

        # Crop.
        border = max(int(np.rint(qsize * 0.1)), 3)
        crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                int(np.ceil(max(quad[:, 1]))))
        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
                min(crop[3] + border, img.size[1]))
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            quad -= crop[0:2]

        # Transform.
        quad = (quad + 0.5).flatten()
        lx = max(min(quad[0], quad[2]), 0)
        ly = max(min(quad[1], quad[7]), 0)
        rx = min(max(quad[4], quad[6]), img.size[0])
        ry = min(max(quad[3], quad[5]), img.size[0])

        # Save aligned image.
        return crop, [lx, ly, rx, ry]

    def crop(self, img_np_list, xsize=512):  # use the first detectable frame's face box for the whole video
        idx = 0
        while idx < len(img_np_list)//2:  # TODO
            img_np = img_np_list[idx]
            lm = self.get_landmark(img_np)
            if lm is not None:
                break  # can detect face
            idx += 1
        if lm is None:
            return None

        crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
        clx, cly, crx, cry = crop
        lx, ly, rx, ry = quad
        lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
        for _i in range(len(img_np_list)):
            _inp = img_np_list[_i]
            _inp = _inp[cly:cry, clx:crx]
            _inp = _inp[ly:ry, lx:rx]
            img_np_list[_i] = _inp
        return img_np_list, crop, quad

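
A hedged usage sketch of Croper (the predictor path and frame files are assumptions; frames should be RGB uint8 arrays):

import cv2
from utils.ffhq_preprocess import Croper

croper = Croper('checkpoints/shape_predictor_68_face_landmarks.dat')  # assumed model location
frames = [cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) for p in ['0001.png', '0002.png']]
result = croper.crop(frames, xsize=512)
if result is not None:
    cropped_frames, crop_box, quad_box = result  # the same face region cut from every frame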
videoretalking/utils/flow_util.py ADDED
@@ -0,0 +1,56 @@
import torch

def convert_flow_to_deformation(flow):
    r"""convert flow fields to deformations.

    Args:
        flow (tensor): Flow field obtained by the model
    Returns:
        deformation (tensor): The deformation used for warping
    """
    b, c, h, w = flow.shape
    flow_norm = 2 * torch.cat([flow[:, :1, ...] / (w - 1), flow[:, 1:, ...] / (h - 1)], 1)
    grid = make_coordinate_grid(flow)
    deformation = grid + flow_norm.permute(0, 2, 3, 1)
    return deformation

def make_coordinate_grid(flow):
    r"""obtain a coordinate grid with the same size as the flow field.

    Args:
        flow (tensor): Flow field obtained by the model
    Returns:
        grid (tensor): The grid with the same size as the input flow
    """
    b, c, h, w = flow.shape

    x = torch.arange(w).to(flow)
    y = torch.arange(h).to(flow)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
    meshed = meshed.expand(b, -1, -1, -1)
    return meshed


def warp_image(source_image, deformation):
    r"""warp the input image according to the deformation

    Args:
        source_image (tensor): source images to be warped
        deformation (tensor): deformations used to warp the images; value in range (-1, 1)
    Returns:
        output (tensor): the warped images
    """
    _, h_old, w_old, _ = deformation.shape
    _, _, h, w = source_image.shape
    if h_old != h or w_old != w:
        deformation = deformation.permute(0, 3, 1, 2)
        deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
        deformation = deformation.permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(source_image, deformation)
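
A self-contained shape check for these helpers (random tensors stand in for real model outputs):

import torch
from utils.flow_util import convert_flow_to_deformation, warp_image

source = torch.randn(2, 3, 64, 64)               # (B, C, H, W) source images
flow = torch.randn(2, 2, 64, 64)                 # (B, 2, H, W) flow field in pixels
deformation = convert_flow_to_deformation(flow)  # (B, H, W, 2), normalized to [-1, 1]
warped = warp_image(source, deformation)         # (B, C, H, W) warped output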
videoretalking/utils/hparams.py ADDED
@@ -0,0 +1,137 @@
import os

class HParams:
    def __init__(self, **kwargs):
        self.data = {}

        for key, value in kwargs.items():
            self.data[key] = value

    def __getattr__(self, key):
        if key not in self.data:
            raise AttributeError("'HParams' object has no attribute %s" % key)
        return self.data[key]

    def set_hparam(self, key, value):
        self.data[key] = value


# Default hyperparameters
hparams = HParams(
    num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
    # network
    rescale=True,  # Whether to rescale audio prior to preprocessing
    rescaling_max=0.9,  # Rescaling value

    # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
    # It's preferable to set this to True when used with https://github.com/r9y9/wavenet_vocoder
    # Does not work if n_fft is not a multiple of hop_size!!
    use_lws=False,

    n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
    hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
    win_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
    sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)

    frame_shift_ms=None,  # Can replace the hop_size parameter. (Recommended: 12.5)

    # Mel and Linear spectrograms normalization/scaling and clipping
    signal_normalization=True,
    # Whether to normalize mel spectrograms to some predefined range (following below parameters)
    allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
    symmetric_mels=True,
    # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
    # faster and cleaner convergence)
    max_abs_value=4.,
    # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
    # be too big to avoid gradient explosion,
    # not too small for fast convergence)
    # Contribution by @begeekmyfriend
    # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
    # levels. Also allows for better G&L phase reconstruction)
    preemphasize=True,  # whether to apply filter
    preemphasis=0.97,  # filter coefficient.

    # Limits
    min_level_db=-100,
    ref_level_db=20,
    fmin=55,
    # Set this to 55 if your speaker is male! If female, 95 should help taking off noise. (To
    # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
    fmax=7600,  # To be increased/reduced depending on data.

    ###################### Our training parameters #################################
    img_size=96,
    fps=25,

    batch_size=8,
    initial_learning_rate=1e-4,
    nepochs=300000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
    num_workers=20,
    checkpoint_interval=3000,
    eval_interval=3000,
    writer_interval=300,
    save_optimizer_state=True,

    syncnet_wt=0.0,  # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
    syncnet_batch_size=64,
    syncnet_lr=1e-4,
    syncnet_eval_interval=10000,
    syncnet_checkpoint_interval=10000,

    disc_wt=0.07,
    disc_initial_learning_rate=1e-4,
)



# Default hyperparameters (debug)
hparamsdebug = HParams(
    num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
    # network
    rescale=True,  # Whether to rescale audio prior to preprocessing
    rescaling_max=0.9,  # Rescaling value

    # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
    # It's preferable to set this to True when used with https://github.com/r9y9/wavenet_vocoder
    # Does not work if n_fft is not a multiple of hop_size!!
    use_lws=False,

    n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
    hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
    win_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
    sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)

    frame_shift_ms=None,  # Can replace the hop_size parameter. (Recommended: 12.5)

    # Mel and Linear spectrograms normalization/scaling and clipping
    signal_normalization=True,
    # Whether to normalize mel spectrograms to some predefined range (following below parameters)
    allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
    symmetric_mels=True,
    # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
    # faster and cleaner convergence)
    max_abs_value=4.,
    # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
    # be too big to avoid gradient explosion,
    # not too small for fast convergence)
    # Contribution by @begeekmyfriend
    # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
    # levels. Also allows for better G&L phase reconstruction)
    preemphasize=True,  # whether to apply filter
    preemphasis=0.97,  # filter coefficient.

    # Limits
    min_level_db=-100,
    ref_level_db=20,
    fmin=55,
    # Set this to 55 if your speaker is male! If female, 95 should help taking off noise. (To
    # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
    fmax=7600,  # To be increased/reduced depending on data.
)


def hparams_debug_string():
    values = hparams.data  # HParams keeps everything in .data; it has no values() method
    hp = ["  %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
    return "Hyperparameters:\n" + "\n".join(hp)
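
A small sketch of reading and overriding a value (the override follows the fmin comment above and is purely illustrative):

from utils.hparams import hparams as hp

print(hp.sample_rate, hp.hop_size)  # 16000 200
hp.set_hparam('fmin', 95)           # e.g. raise fmin for higher-pitched speakers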
videoretalking/utils/inference_utils.py ADDED
@@ -0,0 +1,254 @@
import numpy as np
import cv2, argparse, torch
import torchvision.transforms.functional as TF

from models import load_network, load_DNet
from tqdm import tqdm
from PIL import Image
from scipy.spatial import ConvexHull
from third_part import face_detection
from third_part.face3d.models import networks

import warnings
warnings.filterwarnings("ignore")

def options():
    parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')

    parser.add_argument('--DNet_path', type=str, default='checkpoints/DNet.pt')
    parser.add_argument('--LNet_path', type=str, default='checkpoints/LNet.pth')
    parser.add_argument('--ENet_path', type=str, default='checkpoints/ENet.pth')
    parser.add_argument('--face3d_net_path', type=str, default='checkpoints/face3d_pretrain_epoch_20.pth')
    parser.add_argument('--face', type=str, help='Filepath of video/image that contains faces to use', required=True)
    parser.add_argument('--audio', type=str, help='Filepath of video/audio file to use as raw audio source', required=True)
    parser.add_argument('--exp_img', type=str, help='Expression template. neutral, smile or image path', default='neutral')
    parser.add_argument('--outfile', type=str, help='Video path to save result')

    parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', default=25., required=False)
    parser.add_argument('--pads', nargs='+', type=int, default=[0, 20, 0, 0], help='Padding (top, bottom, left, right). Please adjust to include at least the chin')
    parser.add_argument('--face_det_batch_size', type=int, help='Batch size for face detection', default=4)
    parser.add_argument('--LNet_batch_size', type=int, help='Batch size for LNet', default=16)
    parser.add_argument('--img_size', type=int, default=384)
    parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
                        help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
                             'Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height, width')
    parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
                        help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
                             'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
    parser.add_argument('--nosmooth', default=False, action='store_true', help='Prevent smoothing face detections over a short temporal window')
    parser.add_argument('--static', default=False, action='store_true')


    parser.add_argument('--up_face', default='original')
    parser.add_argument('--one_shot', action='store_true')
    parser.add_argument('--without_rl1', default=False, action='store_true', help='Do not use the relative l1')
    parser.add_argument('--tmp_dir', type=str, default='temp', help='Folder to save tmp results')
    parser.add_argument('--re_preprocess', action='store_true')

    args = parser.parse_args()
    return args

exp_aus_dict = {  # AU01_r, AU02_r, AU04_r, AU05_r, AU06_r, AU07_r, AU09_r, AU10_r, AU12_r, AU14_r, AU15_r, AU17_r, AU20_r, AU23_r, AU25_r, AU26_r, AU45_r.
    'sad': torch.Tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
    'angry': torch.Tensor([[0, 0, 0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
    'surprise': torch.Tensor([[0, 0, 0, 0.2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
}

def mask_postprocess(mask, thres=20):
    mask[:thres, :] = 0; mask[-thres:, :] = 0
    mask[:, :thres] = 0; mask[:, -thres:] = 0
    mask = cv2.GaussianBlur(mask, (101, 101), 11)
    mask = cv2.GaussianBlur(mask, (101, 101), 11)
    return mask.astype(np.float32)

def trans_image(image):
    image = TF.resize(
        image, size=256, interpolation=Image.BICUBIC)
    image = TF.to_tensor(image)
    image = TF.normalize(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    return image

def obtain_seq_index(index, num_frames):
    seq = list(range(index - 13, index + 13))
    seq = [min(max(item, 0), num_frames - 1) for item in seq]
    return seq

def transform_semantic(semantic, frame_index, crop_norm_ratio=None):
    index = obtain_seq_index(frame_index, semantic.shape[0])

    coeff_3dmm = semantic[index, ...]
    ex_coeff = coeff_3dmm[:, 80:144]      # expression (64)
    angles = coeff_3dmm[:, 224:227]       # euler angles for pose
    translation = coeff_3dmm[:, 254:257]  # translation
    crop = coeff_3dmm[:, 259:262]         # crop param

    if crop_norm_ratio:
        crop[:, -3] = crop[:, -3] * crop_norm_ratio

    coeff_3dmm = np.concatenate([ex_coeff, angles, translation, crop], 1)
    return torch.Tensor(coeff_3dmm).permute(1, 0)

def find_crop_norm_ratio(source_coeff, target_coeffs):
    alpha = 0.3
    exp_diff = np.mean(np.abs(target_coeffs[:, 80:144] - source_coeff[:, 80:144]), 1)      # mean expression difference
    angle_diff = np.mean(np.abs(target_coeffs[:, 224:227] - source_coeff[:, 224:227]), 1)  # mean angle difference
    index = np.argmin(alpha * exp_diff + (1 - alpha) * angle_diff)  # find the index with the smallest difference
    crop_norm_ratio = source_coeff[:, -3] / target_coeffs[index:index + 1, -3]
    return crop_norm_ratio

def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i: i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images, face_det_batch_size, nosmooth, pads, jaw_correction, detector=None):
    # def face_detect(images, args, jaw_correction=False, detector=None):
    if detector is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                                flip_input=False, device=device)

    batch_size = face_det_batch_size
    while 1:
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size), desc='FaceDet:'):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = pads if jaw_correction else (0, 20, 0, 0)
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if not nosmooth:
        boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    torch.cuda.empty_cache()
    return results

def _load(checkpoint_path, device):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def split_coeff(coeffs):
    """
    Return:
        coeffs_dict -- a dict of torch.tensors

    Parameters:
        coeffs -- torch.tensor, size (B, 256)
    """
    id_coeffs = coeffs[:, :80]
    exp_coeffs = coeffs[:, 80: 144]
    tex_coeffs = coeffs[:, 144: 224]
    angles = coeffs[:, 224: 227]
    gammas = coeffs[:, 227: 254]
    translations = coeffs[:, 254:]
    return {
        'id': id_coeffs,
        'exp': exp_coeffs,
        'tex': tex_coeffs,
        'angle': angles,
        'gamma': gammas,
        'trans': translations
    }

def Laplacian_Pyramid_Blending_with_mask(A, B, m, num_levels=6):
    # generate Gaussian pyramids for A, B and the mask
    GA = A.copy()
    GB = B.copy()
    GM = m.copy()
    gpA = [GA]
    gpB = [GB]
    gpM = [GM]
    for i in range(num_levels):
        GA = cv2.pyrDown(GA)
        GB = cv2.pyrDown(GB)
        GM = cv2.pyrDown(GM)
        gpA.append(np.float32(GA))
        gpB.append(np.float32(GB))
        gpM.append(np.float32(GM))

    # generate Laplacian pyramids for A, B and the masks
    lpA = [gpA[num_levels - 1]]  # the bottom of the Lap-pyr holds the last (smallest) Gauss level
    lpB = [gpB[num_levels - 1]]
    gpMr = [gpM[num_levels - 1]]
    for i in range(num_levels - 1, 0, -1):
        # Laplacian: subtract upscaled version of lower level from current level
        # to get the high frequencies
        LA = np.subtract(gpA[i - 1], cv2.pyrUp(gpA[i]))
        LB = np.subtract(gpB[i - 1], cv2.pyrUp(gpB[i]))
        lpA.append(LA)
        lpB.append(LB)
        gpMr.append(gpM[i - 1])  # also reverse the masks

    # Now blend images according to the mask in each level
    LS = []
    for la, lb, gm in zip(lpA, lpB, gpMr):
        gm = gm[:, :, np.newaxis]
        ls = la * gm + lb * (1.0 - gm)
        LS.append(ls)

    # now reconstruct
    ls_ = LS[0]
    for i in range(1, num_levels):
        ls_ = cv2.pyrUp(ls_)
        ls_ = cv2.add(ls_, LS[i])
    return ls_

def load_model(device, DNet_path, LNet_path, ENet_path):
    D_Net = load_DNet(DNet_path).to(device)
    model = load_network(LNet_path, ENet_path).to(device)
    return D_Net, model

def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False):
    if adapt_movement_scale:
        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    kp_new = {k: v for k, v in kp_driving.items()}
    if use_relative_movement:
        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
        kp_value_diff *= adapt_movement_scale
        kp_new['value'] = kp_value_diff + kp_source['value']

        if use_relative_jacobian:
            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
    return kp_new

def load_face3d_net(ckpt_path, device):
    net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    net_recon.load_state_dict(checkpoint['net_recon'])
    net_recon.eval()
    return net_recon
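
A rough sketch of calling the face-detection helper (frame paths are placeholders; running it needs the full VideoReTalking repo on the path, since this module imports models and third_part at import time, plus the downloaded checkpoints):

import cv2
from utils.inference_utils import face_detect

frames = [cv2.imread(p) for p in ['0001.png', '0002.png']]  # hypothetical BGR frames
results = face_detect(frames, face_det_batch_size=4, nosmooth=False,
                      pads=[0, 20, 0, 0], jaw_correction=True)
for face_crop, (y1, y2, x1, x2) in results:
    pass  # face_crop is the padded face region; (y1, y2, x1, x2) is its box in the frame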