Spaces:
Paused
Paused
import os | |
import glob | |
import numpy as np | |
from os import makedirs, name | |
from PIL import Image | |
from tqdm import tqdm | |
import torch | |
import torch.nn as nn | |
from face3d.options.inference_options import InferenceOptions | |
from face3d.models import create_model | |
from face3d.util.preprocess import align_img | |
from face3d.util.load_mats import load_lm3d | |
from face3d.extract_kp_videos import KeypointExtractor | |
class CoeffDetector(nn.Module):
    """Predict 3DMM coefficients for a single face image.

    Wraps the model returned by ``create_model(opt)`` together with the
    landmark template returned by ``load_lm3d(opt.bfm_folder)``.
    """

    def __init__(self, opt):
        """Create the model from ``opt`` and put it into evaluation mode.

        param:
            opt: -- parsed options; must provide ``bfm_folder`` plus whatever
                    ``create_model`` / ``model.setup`` read from it
        """
        super().__init__()
        self.model = create_model(opt)
        self.model.setup(opt)
        # NOTE(review): the device is hard-coded, so a CUDA device is
        # presumably required for ``parallelize`` to succeed.
        self.model.device = 'cuda'
        self.model.parallelize()
        self.model.eval()
        self.lm3d_std = load_lm3d(opt.bfm_folder)

    def forward(self, img, lm):
        """Run the model on one image and collect its coefficients.

        param:
            img: -- PIL image
            lm: -- numpy array of 2D landmarks (one ``x, y`` row per point);
                   it is not modified
        return:
            dict with
            'coeff_3dmm': -- numpy array; the 'id', 'exp', 'tex', 'angle',
                             'gamma' and 'trans' predictions followed by the
                             five alignment parameters, joined along axis 1
            'crop_img': -- PIL image of the aligned crop fed to the model
        """
        img, trans_params = self.image_transform(img, lm)
        # ``img[None]`` adds the leading batch axis.
        self.model.set_input({'imgs': img[None]})
        self.model.test()

        predicted = self.model.pred_coeffs_dict
        coeff_keys = ('id', 'exp', 'tex', 'angle', 'gamma', 'trans')
        parts = [predicted[key].cpu().numpy() for key in coeff_keys]
        # The alignment parameters are appended as one extra row-block so the
        # crop can be related back to the source image later.
        parts.append(trans_params[None].numpy())
        pred_coeff = np.concatenate(parts, 1)

        # Back from CHW floats in [0, 1] to an HWC uint8 image.
        crop_array = (img.cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return {'coeff_3dmm': pred_coeff,
                'crop_img': Image.fromarray(crop_array)}

    def image_transform(self, images, lm):
        """Align an image to the landmark template and convert it to a tensor.

        param:
            images: -- PIL image
            lm: -- numpy array of 2D landmarks; left untouched (the vertical
                   flip below is applied to a copy)
        return:
            img: -- float32 tensor, channels first, values scaled by 1/255
            trans_params: -- float32 tensor holding the five alignment
                             parameters returned by ``align_img``
        """
        W, H = images.size
        if np.mean(lm) == -1:
            # An all -1 landmark array is presumably the extractor's
            # "no face found" marker (TODO confirm). Fall back to the
            # template's x/y columns, mapped from [-1, 1] to pixel coordinates.
            lm = (self.lm3d_std[:, :2] + 1) / 2.
            lm = np.concatenate(
                [lm[:, :1] * W, lm[:, 1:2] * H], 1
            )
        else:
            # Copy first: flipping in place would silently alter the caller's
            # array, and a second call with the same array would flip it back.
            lm = lm.copy()
            # Mirror the y coordinate so it is measured from the bottom edge.
            lm[:, -1] = H - 1 - lm[:, -1]

        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)

        img = torch.tensor(np.array(img) / 255., dtype=torch.float32).permute(2, 0, 1)
        # Split into five pieces and turn each one into a plain float.
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, trans_params
def get_data_path(root, keypoint_root):
    """List the images directly inside ``root`` and their keypoint files.

    A file counts as an image when its extension is jpg, png, jpeg or webp,
    compared case-insensitively, so ``.JPG`` and ``.Jpeg`` are accepted too.
    Each entry is returned exactly once and the result is sorted by path.

    param:
        root: -- directory to scan (not recursive)
        keypoint_root: -- directory in which keypoint files live
    return:
        filenames: -- sorted list of image paths
        keypoint_filenames: -- ``<keypoint_root>/<image stem>.txt`` for each
                               image, in the same order
    """
    image_extensions = {'jpg', 'png', 'jpeg', 'webp'}

    # One glob over the directory, then filter on the lowercased extension.
    # Globbing once per lower/upper-case extension misses mixed-case names and
    # returns duplicates where pattern matching is case-insensitive.
    filenames = sorted(
        path for path in glob.glob(f'{root}/*')
        if os.path.splitext(path)[1][1:].lower() in image_extensions
    )

    keypoint_filenames = [
        os.path.join(
            keypoint_root,
            os.path.splitext(os.path.basename(path))[0] + '.txt',
        )
        for path in filenames
    ]
    return filenames, keypoint_filenames
if __name__ == "__main__":
    # Batch driver: for every image in opt.input_dir, obtain its landmarks
    # (cached in opt.keypoint_dir when available) and write the predicted
    # coefficients to "<opt.output_dir>/<image stem>_3dmm_coeff.txt".
    opt = InferenceOptions().parse()
    coeff_detector = CoeffDetector(opt)
    kp_extractor = KeypointExtractor()
    image_names, keypoint_names = get_data_path(opt.input_dir, opt.keypoint_dir)
    makedirs(opt.keypoint_dir, exist_ok=True)
    makedirs(opt.output_dir, exist_ok=True)

    # zip() has no length, so pass the total for a meaningful progress bar.
    for image_name, keypoint_name in tqdm(zip(image_names, keypoint_names),
                                          total=len(image_names)):
        # Close the file handle as soon as the pixels are loaded. convert()
        # returns an independent in-memory image and guarantees three
        # channels for RGBA, palette and greyscale inputs.
        with Image.open(image_name) as image_file:
            image = image_file.convert('RGB')

        if not os.path.isfile(keypoint_name):
            # No cached landmarks: run the extractor, which is handed the
            # target path (presumably to store its result there).
            lm = kp_extractor.extract_keypoint(image, keypoint_name)
        else:
            lm = np.loadtxt(keypoint_name).astype(np.float32)
            lm = lm.reshape([-1, 2])

        predicted = coeff_detector(image, lm)
        stem = os.path.splitext(os.path.basename(image_name))[0]
        np.savetxt(
            "{}/{}_3dmm_coeff.txt".format(opt.output_dir, stem),
            predicted['coeff_3dmm'].reshape(-1))