|
|
|
|
|
""" |
|
Functions for processing and transforming 3D facial keypoints.
|
""" |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
# Module-level alias of ``np.pi`` (a plain Python float).
PI: float = np.pi
|
|
|
|
|
def headpose_pred_to_degree(pred):
    """Turn a head-pose prediction into an angle in degrees.

    pred: (bs, 66) or (bs, 1) or others.

    If ``pred`` has more than one dimension and exactly 66 entries along
    dim 1, it is treated as logits over 66 bins. The softmax-weighted mean
    bin index ``i`` is mapped to ``3 * i - 97.5``. For a (bs, 66) input this
    yields a (bs,) tensor with values in [-97.5, 97.5].

    Any other input is returned unchanged (the very same tensor object).
    """
    binned = pred.ndim > 1 and pred.shape[1] == 66
    if not binned:
        return pred

    # Bin indices 0..65 as float32, created directly on pred's device.
    bins = torch.arange(66, dtype=torch.float32, device=pred.device)
    probs = F.softmax(pred, dim=1)
    # Expected bin index, then 3 degrees per bin with a -97.5 offset.
    expected_bin = torch.sum(probs * bins, dim=1)
    return expected_bin * 3 - 97.5
|
|
|
|
|
def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build batched 3x3 rotation matrices from Euler angles in degrees.

    Args:
        pitch_: angle(s) in degrees for the rotation about the x-axis.
        yaw_: angle(s) in degrees for the rotation about the y-axis.
        roll_: angle(s) in degrees for the rotation about the z-axis.
            Each argument is a tensor of shape ``()``, ``(bs,)`` or
            ``(bs, 1)``. A 0-d tensor is treated as a batch of one, and a
            batch of one broadcasts against a batch of ``bs``.

    Returns:
        Tensor of shape ``(bs, 3, 3)`` holding the transpose of
        ``Rz(roll) @ Ry(yaw) @ Rx(pitch)``. The constant 0/1 entries are
        created with torch's default dtype on ``pitch_``'s device.
    """
    # Degrees -> radians (np.pi is the same value as the module's PI).
    pitch = pitch_ / 180 * np.pi
    yaw = yaw_ / 180 * np.pi
    roll = roll_ / 180 * np.pi

    device = pitch.device

    # Normalise every angle to a (bs, 1) column so the nine matrix entries
    # can be concatenated along dim=1. For 1-D input reshape(-1, 1) equals
    # unsqueeze(1). For 0-d input it gives (1, 1); previously that case
    # crashed on ``shape[0]``. 2-D input is left as-is.
    pitch, yaw, roll = (
        a.reshape(-1, 1) if a.ndim < 2 else a for a in (pitch, yaw, roll)
    )
    # Let a batch of one (e.g. a shared scalar angle) broadcast against the
    # others; equally-sized batches pass through unchanged.
    pitch, yaw, roll = torch.broadcast_tensors(pitch, yaw, roll)

    bs = pitch.shape[0]
    # Allocate directly on the target device rather than on CPU + copy.
    ones = torch.ones([bs, 1], device=device)
    zeros = torch.zeros([bs, 1], device=device)

    cos_x, sin_x = torch.cos(pitch), torch.sin(pitch)
    cos_y, sin_y = torch.cos(yaw), torch.sin(yaw)
    cos_z, sin_z = torch.cos(roll), torch.sin(roll)

    # Each matrix is written row-major as nine (bs, 1) columns, then reshaped.
    rot_x = torch.cat([
        ones, zeros, zeros,
        zeros, cos_x, -sin_x,
        zeros, sin_x, cos_x,
    ], dim=1).reshape([bs, 3, 3])

    rot_y = torch.cat([
        cos_y, zeros, sin_y,
        zeros, ones, zeros,
        -sin_y, zeros, cos_y,
    ], dim=1).reshape([bs, 3, 3])

    rot_z = torch.cat([
        cos_z, -sin_z, zeros,
        sin_z, cos_z, zeros,
        zeros, zeros, ones,
    ], dim=1).reshape([bs, 3, 3])

    rot = rot_z @ rot_y @ rot_x
    # NOTE(review): the transpose is returned, presumably so callers can
    # rotate row-vector points as ``pts @ R``; confirm against callers.
    return rot.permute(0, 2, 1)
|
|