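"""Extract per-frame 3DMM coefficients from videos with a Deep3DFaceRecon-style model.

For every video found under `opt.input_dir`, the script loads the matching
facial-landmark .txt file from `opt.keypoint_dir`, aligns each frame against
the standard BFM landmarks, runs the reconstruction model, and saves the
predicted coefficients together with the per-frame alignment parameters as a
.mat file under `opt.output_dir`.
"""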
import os
import glob

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from scipy.io import savemat

from models import create_model
from options.inference_options import InferenceOptions
from util.preprocess import align_img
from util.load_mats import load_lm3d
from util.util import tensor2im, save_image


def get_data_path(root, keypoint_root):
    """Collect video paths under `root` and landmark .txt paths under `keypoint_root`."""
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})

    filenames = []
    for ext in VIDEO_EXTENSIONS:
        filenames += glob.glob(f'{root}/**/*.{ext}', recursive=True)
    filenames = sorted(filenames)
    keypoint_filenames = sorted(glob.glob(f'{keypoint_root}/**/*.txt', recursive=True))
    # Pairing relies on both trees sorting into the same order, so every video
    # must have exactly one landmark file.
    assert len(filenames) == len(keypoint_filenames)
    return filenames, keypoint_filenames
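
# A hypothetical directory layout that satisfies the pairing above (names are
# illustrative, not required by the code; only the sorted order must match):
#
#   {root}/actor01/clip001.mp4    <->   {keypoint_root}/actor01/clip001.txt
#   {root}/actor01/clip002.mp4    <->   {keypoint_root}/actor01/clip002.txt
#
# Each .txt holds flattened per-frame (x, y) landmark coordinates that reshape
# to (num_frames, num_landmarks, 2) in VideoPathDataset.__getitem__.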


class VideoPathDataset(torch.utils.data.Dataset):
    def __init__(self, filenames, txt_filenames, bfm_folder):
        self.filenames = filenames
        self.txt_filenames = txt_filenames
        self.lm3d_std = load_lm3d(bfm_folder)  # standard 3D landmarks of the BFM face model

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        filename = self.filenames[index]
        txt_filename = self.txt_filenames[index]
        frames = self.read_video(filename)
        lm = np.loadtxt(txt_filename).astype(np.float32)
        lm = lm.reshape([len(frames), -1, 2])  # (num_frames, num_landmarks, 2)

        out_images, out_trans_params = [], []
        for i in range(len(frames)):
            out_img, _, out_trans_param = self.image_transform(frames[i], lm[i])
            out_images.append(out_img[None])
            out_trans_params.append(out_trans_param[None])

        return {
            'imgs': torch.cat(out_images, 0),
            'trans_param': torch.cat(out_trans_params, 0),
            'filename': filename,
        }
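
    # Each item is one whole video: 'imgs' has shape (num_frames, 3, H, W) with
    # H and W fixed by align_img (224x224 in the reference Deep3DFaceRecon
    # preprocessing; assumed here, not verified), and 'trans_param' has shape
    # (num_frames, 5).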

    def read_video(self, filename):
        """Decode every frame of `filename` into an RGB PIL image."""
        frames = []
        cap = cv2.VideoCapture(filename)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            frames.append(Image.fromarray(frame))
        cap.release()
        return frames
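
    # Note: the whole clip is decoded into memory at once, which is fine for
    # short clips but can be costly for long videos.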

    def image_transform(self, images, lm):
        W, H = images.size
        if np.mean(lm) == -1:
            # All-(-1) landmarks mark a frame with no detection: fall back to the
            # standard BFM landmark positions, rescaled from [-1, 1] to pixels.
            lm = (self.lm3d_std[:, :2] + 1) / 2.
            lm = np.concatenate([lm[:, :1] * W, lm[:, 1:2] * H], 1)
        else:
            lm[:, -1] = H - 1 - lm[:, -1]  # flip the y-axis to the convention align_img expects
        trans_params, img, lm, _ = align_img(images, lm, self.lm3d_std)
        img = torch.tensor(np.array(img) / 255., dtype=torch.float32).permute(2, 0, 1)
        lm = torch.tensor(lm)
        trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)])
        trans_params = torch.tensor(trans_params.astype(np.float32))
        return img, lm, trans_params
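
    # np.hsplit(trans_params, 5) assumes align_img returns exactly five scalars;
    # in Deep3DFaceRecon-style preprocessing these are typically
    # [raw_W, raw_H, scale, tx, ty] (assumed, not verified against this repo's
    # align_img).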


def main(opt, model):
    # import torch.multiprocessing
    # torch.multiprocessing.set_sharing_strategy('file_system')
    filenames, keypoint_filenames = get_data_path(opt.input_dir, opt.keypoint_dir)
    dataset = VideoPathDataset(filenames, keypoint_filenames, opt.bfm_folder)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,  # each dataset item is already a whole video, so this can only be 1
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )
    batch_size = opt.inference_batch_size
    for data in tqdm(dataloader):
        num_frames = data['imgs'][0].shape[0]
        # Ceil division: include the final partial batch without creating an
        # empty trailing one when num_frames is a multiple of batch_size.
        num_batch = (num_frames + batch_size - 1) // batch_size
        pred_coeffs = []
        for index in range(num_batch):
            data_input = {
                'imgs': data['imgs'][0, index * batch_size:(index + 1) * batch_size],
            }
            model.set_input(data_input)
            model.test()
            pred_coeff = {key: model.pred_coeffs_dict[key].cpu().numpy() for key in model.pred_coeffs_dict}
            # Flatten the grouped 3DMM coefficients into one row per frame.
            pred_coeff = np.concatenate([
                pred_coeff['id'],
                pred_coeff['exp'],
                pred_coeff['tex'],
                pred_coeff['angle'],
                pred_coeff['gamma'],
                pred_coeff['trans']], 1)
            pred_coeffs.append(pred_coeff)
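            # For the standard Deep3DFaceRecon BFM model the row is
            # 80 (id) + 64 (exp) + 80 (tex) + 3 (angle) + 27 (gamma) + 3 (trans)
            # = 257 values per frame; the actual sizes depend on the checkpoint
            # (assumed dimensions, not verified here).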
            visuals = model.get_current_visuals()  # rendered reconstructions, for inspection
            if False:  # debug: dump the rendered frames and stop after the first batch
                for name in visuals:
                    images = visuals[name]
                    for i in range(images.shape[0]):
                        image_numpy = tensor2im(images[i])
                        save_image(
                            image_numpy,
                            os.path.join(
                                opt.output_dir,
                                os.path.basename(data['filename'][0]) + str(i).zfill(5) + '.jpg')
                        )
                exit()

        pred_coeffs = np.concatenate(pred_coeffs, 0)
        pred_trans_params = data['trans_param'][0].cpu().numpy()
        # Mirror the video's parent directory under output_dir and save as .mat
        # (note: split('/') assumes POSIX-style paths).
        name = data['filename'][0].split('/')[-2:]
        name[-1] = os.path.splitext(name[-1])[0] + '.mat'
        os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
        savemat(
            os.path.join(opt.output_dir, name[-2], name[-1]),
            {'coeff': pred_coeffs, 'transform_params': pred_trans_params}
        )


if __name__ == '__main__':
    opt = InferenceOptions().parse()  # get test options
    model = create_model(opt)
    model.setup(opt)
    model.device = 'cuda:0'
    model.parallelize()
    model.eval()
    main(opt, model)
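
# Example invocation (illustrative; the script name and exact flag names are
# assumptions, since the real flags are defined in options/inference_options.py):
#
#   python extract_video_coeffs.py \
#       --input_dir /data/videos \
#       --keypoint_dir /data/keypoints \
#       --output_dir /data/coeffs \
#       --bfm_folder BFM \
#       --inference_batch_size 32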