Spanicin's picture
Upload 77 files
5c012bf verified
raw
history blame
4.01 kB
import os
import glob
import numpy as np
from os import makedirs, name
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
from face3d.options.inference_options import InferenceOptions
from face3d.models import create_model
from face3d.util.preprocess import align_img
from face3d.util.load_mats import load_lm3d
from face3d.extract_kp_videos import KeypointExtractor
class CoeffDetector(nn.Module):
    """Regress 3DMM coefficients for a single face image.

    Wraps a face3d reconstruction model built from ``opt``. Each call aligns
    the input image using its 2D landmarks, runs the model, and returns the
    predicted coefficients (concatenated with the alignment parameters)
    together with the aligned image.
    """

    def __init__(self, opt):
        super().__init__()
        self.model = create_model(opt)
        self.model.setup(opt)
        # NOTE(review): the device is hard-coded, so this presumably requires
        # a CUDA-capable GPU — confirm before running on other hardware.
        self.model.device = 'cuda'
        self.model.parallelize()
        self.model.eval()
        # Landmark template used as the alignment target; its (x, y) columns
        # also serve as fallback landmarks in image_transform.
        self.lm3d_std = load_lm3d(opt.bfm_folder)

    def forward(self, img, lm):
        """Predict 3DMM coefficients for one image.

        Args:
            img: PIL image containing the face.
            lm: numpy array of 2D landmarks, one (x, y) row per point. An
                array whose mean is exactly -1 means "no landmarks" (see
                image_transform). The caller's array is not modified.

        Returns:
            dict with
              - 'coeff_3dmm': numpy array of shape (1, K): the 'id', 'exp',
                'tex', 'angle', 'gamma' and 'trans' predictions followed by
                the 5 alignment parameters, concatenated in that order.
              - 'crop_img': the aligned image as an 8-bit PIL image.
        """
        img, trans_params = self.image_transform(img, lm)
        data_input = {
            'imgs': img[None],  # add a leading batch dimension of size 1
        }
        self.model.set_input(data_input)
        self.model.test()
        pred_coeff = {key: self.model.pred_coeffs_dict[key].cpu().numpy()
                      for key in self.model.pred_coeffs_dict}
        pred_coeff = np.concatenate([
            pred_coeff['id'],
            pred_coeff['exp'],
            pred_coeff['tex'],
            pred_coeff['angle'],
            pred_coeff['gamma'],
            pred_coeff['trans'],
            trans_params[None],
        ], 1)
        crop = img.cpu().permute(1, 2, 0).numpy() * 255
        # Round rather than truncate: the float32 round trip through /255 and
        # *255 is not guaranteed to be exact, and astype(uint8) would floor a
        # value such as 2.9999998 down to 2.
        crop = np.clip(np.rint(crop), 0, 255).astype(np.uint8)
        return {'coeff_3dmm': pred_coeff,
                'crop_img': Image.fromarray(crop)}

    def image_transform(self, images, lm):
        """Align an image to the landmark template.

        param:
            images: -- PIL image
            lm: -- numpy array of (x, y) landmarks, or an array whose mean is
                   -1 to use the template's own landmark positions instead.
                   The caller's array is never modified.

        return:
            img: float32 tensor, channels first (permuted from H, W, C),
                 holding the aligned image's pixel values divided by 255.
            trans_params: float32 tensor of the 5 parameters returned by
                 align_img.
        """
        W, H = images.size
        if np.mean(lm) == -1:
            # No landmarks: take the template's (x, y) points, rescale them
            # with (v + 1) / 2 (presumably from [-1, 1] to [0, 1]), then
            # scale to pixel coordinates.
            lm = (self.lm3d_std[:, :2] + 1) / 2.
            lm = np.concatenate(
                [lm[:, :1] * W, lm[:, 1:2] * H], 1
            )
        else:
            # Flip the y axis (presumably the convention align_img expects).
            # Work on a copy: flipping the caller's array in place meant that
            # passing the same array twice flipped it back on the second call.
            lm = np.array(lm)
            lm[:, -1] = H - 1 - lm[:, -1]
        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)
        img = torch.tensor(np.array(img) / 255., dtype=torch.float32).permute(2, 0, 1)
        # Split align_img's parameters into 5 parts and coerce each to float.
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, trans_params
def get_data_path(root, keypoint_root):
    """List the images directly inside ``root`` and their keypoint files.

    Args:
        root: directory to scan (non-recursively) for files with a jpg,
            jpeg, png or webp extension, matched case-insensitively.
        keypoint_root: directory in which keypoint files are expected.

    Returns:
        (filenames, keypoint_filenames): the sorted image paths and, in the
        same order, ``<keypoint_root>/<image stem>.txt`` for each of them.
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp'}
    # Glob everything once, then filter by lowercased extension. Globbing per
    # case variant (*.jpg, *.JPG, ...) returned the same file twice on
    # case-insensitive filesystems and missed mixed-case names like x.Jpg.
    # glob.escape keeps characters such as '[' in ``root`` literal.
    candidates = glob.glob(f'{glob.escape(root)}/*')
    filenames = sorted(
        path for path in candidates
        if os.path.splitext(path)[1].lower() in image_extensions
    )
    keypoint_filenames = [
        os.path.join(keypoint_root,
                     os.path.splitext(os.path.basename(path))[0] + '.txt')
        for path in filenames
    ]
    return filenames, keypoint_filenames
if __name__ == "__main__":
    # Extract 3DMM coefficients for every image in opt.input_dir, caching
    # landmarks in opt.keypoint_dir and writing one
    # <stem>_3dmm_coeff.txt per image to opt.output_dir.
    opt = InferenceOptions().parse()

    coeff_detector = CoeffDetector(opt)
    kp_extractor = KeypointExtractor()
    image_names, keypoint_names = get_data_path(opt.input_dir, opt.keypoint_dir)
    makedirs(opt.keypoint_dir, exist_ok=True)
    makedirs(opt.output_dir, exist_ok=True)

    # zip() has no len(), so pass the total to tqdm explicitly.
    for image_name, keypoint_name in tqdm(zip(image_names, keypoint_names),
                                          total=len(image_names)):
        # Decode to RGB and close the file handle. Greyscale or RGBA inputs
        # would otherwise not have exactly three channels, which the
        # channels-first permute in CoeffDetector.image_transform relies on.
        with Image.open(image_name) as raw_image:
            image = raw_image.convert('RGB')
        if not os.path.isfile(keypoint_name):
            # No cached landmarks: detect them. extract_keypoint also receives
            # the target path (presumably so it can save them there).
            lm = kp_extractor.extract_keypoint(image, keypoint_name)
        else:
            lm = np.loadtxt(keypoint_name).astype(np.float32)
            lm = lm.reshape([-1, 2])
        predicted = coeff_detector(image, lm)
        stem = os.path.splitext(os.path.basename(image_name))[0]
        np.savetxt(
            os.path.join(opt.output_dir, f"{stem}_3dmm_coeff.txt"),
            predicted['coeff_3dmm'].reshape(-1))