import itertools

import torch
import torch.nn as nn

import pose_estimation

class MSE(nn.Module):
    """Mean squared error over keypoints, with optional indices to ignore."""

    def __init__(self, ignore=None):
        super().__init__()
        self.mse = torch.nn.MSELoss(reduction="none")
        self.ignore = ignore if ignore is not None else []

    def forward(self, y_pred, y_data):
        loss = self.mse(y_pred, y_data)
        if len(self.ignore) > 0:
            # zero out ignored keypoints so they do not contribute to the mean
            loss[self.ignore] *= 0
        return loss.sum() / (len(loss) - len(self.ignore))
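
# Illustrative usage sketch for MSE (not part of the original pipeline); the
# keypoint count and the ignored indices below are assumed toy values.
def _mse_usage_example():
    criterion = MSE(ignore=[3, 7])
    y_pred = torch.zeros(17, 2, requires_grad=True)  # e.g. 17 predicted 2D keypoints
    y_data = torch.randn(17, 2)  # detected 2D keypoints
    # squared error summed over coordinates, averaged over the 15 kept keypoints
    return criterion(y_pred, y_data)
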
class Parallel(nn.Module):
    """Geometric consistency losses between predicted 2D/3D keypoints and detected 2D keypoints."""

    def __init__(self, skeleton, ignore=None, ground_parallel=None):
        super().__init__()
        self.skeleton = skeleton
        self.ignore = set(ignore) if ignore is not None else set()
        self.ground_parallel = ground_parallel if ground_parallel is not None else []
        self.parallel_in_3d = []
        self.cos = None

    def forward(self, y_pred3d, y_data, z, spine_j, global_step=0):
        y_pred = y_pred3d[:, :2]
        rleg, lleg = spine_j

        # 2D contact consistency: pairs of contact points (interpolated along bones)
        # should keep the same relative offset in the prediction as in the detected
        # keypoints. contact_2d is optional and may be set on the module externally.
        Lcon2d = Lcount = 0
        if hasattr(self, "contact_2d"):
            for c2d in self.contact_2d:
                for (
                    (src_1, dst_1, t_1),
                    (src_2, dst_2, t_2),
                ) in itertools.combinations(c2d, 2):
                    a_1 = torch.lerp(y_data[src_1], y_data[dst_1], t_1)
                    a_2 = torch.lerp(y_data[src_2], y_data[dst_2], t_2)
                    a = a_2 - a_1
                    b_1 = torch.lerp(y_pred[src_1], y_pred[dst_1], t_1)
                    b_2 = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2)
                    b = b_2 - b_1
                    lcon2d = ((a - b) ** 2).sum()
                    Lcon2d = Lcon2d + lcon2d
                    Lcount += 1
        if Lcount > 0:
            Lcon2d = Lcon2d / Lcount
        # Per-bone terms. With self.cos unset, penalise the tangential and normal
        # components of the difference between predicted and detected bones; with
        # self.cos set, penalise deviation of the 2D/3D length ratio from the stored
        # value and deviation of the bone direction from the detected one.
        Ltan = Lpar = Lcos = Lcount = 0
        Lspine = 0  # not accumulated below; stays zero
        for i, bone in enumerate(self.skeleton):
            if bone in self.ignore:
                continue
            src, dst = bone
            b = y_data[dst] - y_data[src]
            t = nn.functional.normalize(b, dim=0)  # unit tangent of the detected bone
            n = torch.stack([-t[1], t[0]])  # unit normal of the detected bone
            if src == 10 and dst == 11:  # right leg
                a = rleg
            elif src == 13 and dst == 14:  # left leg
                a = lleg
            else:
                a = y_pred[dst] - y_pred[src]
            bone_name = f"{pose_estimation.KPS[src]}_{pose_estimation.KPS[dst]}"
            c = a - b
            lcos_loc = ltan_loc = lpar_loc = 0
            if self.cos is not None:
                if bone not in [
                    (1, 2),  # Neck + Right Shoulder
                    (1, 5),  # Neck + Left Shoulder
                    (9, 10),  # Hips + Right Upper Leg
                    (9, 13),  # Hips + Left Upper Leg
                ]:
                    a = y_pred[dst] - y_pred[src]
                    l2d = torch.norm(a, dim=0)
                    l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0)
                    lcos = self.cos[i]
                    lcos_loc = (l2d / l3d - lcos) ** 2
                    Lcos = Lcos + lcos_loc
                    lpar_loc = ((a / l2d) * n).sum() ** 2
                    Lpar = Lpar + lpar_loc
            else:
                ltan_loc = ((c * t).sum()) ** 2
                Ltan = Ltan + ltan_loc
                lpar_loc = (c * n).sum() ** 2
                Lpar = Lpar + lpar_loc
            Lcount += 1
        if Lcount > 0:
            Ltan = Ltan / Lcount
            Lcos = Lcos / Lcount
            Lpar = Lpar / Lcount
            Lspine = Lspine / Lcount
        # Ground-parallel term: the absolute first (x) component of each listed
        # normalized predicted bone should match the given target value.
        Lgr = Lcount = 0
        for (src, dst), value in self.ground_parallel:
            bone = y_pred[dst] - y_pred[src]
            bone = nn.functional.normalize(bone, dim=0)
            l = (torch.abs(bone[0]) - value) ** 2
            Lgr = Lgr + l
            Lcount += 1
        if Lcount > 0:
            Lgr = Lgr / Lcount
        # 3D parallelism term: pairs of 3D segments listed in parallel_in_3d should
        # point in the same direction (unit-vector dot product of 1).
        Lstraight3d = Lcount = 0
        for (i, j), (k, l) in self.parallel_in_3d:
            a = nn.functional.normalize(z[j] - z[i], dim=0)
            b = nn.functional.normalize(z[l] - z[k], dim=0)
            lo = ((a * b).sum() - 1) ** 2
            Lstraight3d = Lstraight3d + lo
            Lcount += 1
        # direction from keypoint 8 to keypoint 1 in the detected pose; computed but
        # not used in the returned terms
        b = y_data[1] - y_data[8]
        b = nn.functional.normalize(b, dim=0)
        if Lcount > 0:
            Lstraight3d = Lstraight3d / Lcount
        return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d
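
# Illustrative usage sketch for Parallel (not part of the original pipeline); the
# skeleton edges, tensor shapes, and spine_j leg directions below are assumed toy
# values, and pose_estimation.KPS is assumed to name at least 15 keypoints.
def _parallel_usage_example():
    skeleton = [(1, 2), (1, 5), (9, 10), (9, 13), (10, 11), (13, 14)]
    criterion = Parallel(skeleton, ignore=[(1, 5)], ground_parallel=[((10, 11), 0.0)])
    y_pred3d = torch.randn(18, 3, requires_grad=True)  # predicted 3D keypoints
    y_data = torch.randn(18, 2)  # detected 2D keypoints
    z = torch.randn(18, 3)  # auxiliary 3D points for parallel_in_3d
    spine_j = (torch.randn(2), torch.randn(2))  # right/left lower-leg 2D directions
    Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = criterion(
        y_pred3d, y_data, z, spine_j
    )
    return Ltan + Lpar + Lgr
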
class MimickedSelfContactLoss(nn.Module):
    def __init__(self, geodesics_mask):
        """
        Loss that lets vertices in contact on the presented mesh attract vertices
        that are close.
        """
        super().__init__()
        # geodesic distance mask
        self.register_buffer("geomask", geodesics_mask)
    def forward(
        self,
        presented_contact,
        vertices,
        v2v=None,
        contact_mode="dist_tanh",
        contact_thresh=1,
    ):
        contactloss = 0.0

        if v2v is None:
            # compute pairwise euclidean distances between all vertices
            verts = vertices.contiguous()
            nv = verts.shape[1]
            v2v = verts.squeeze().unsqueeze(1).expand(
                nv, nv, 3
            ) - verts.squeeze().unsqueeze(0).expand(nv, nv, 3)
            v2v = torch.norm(v2v, 2, 2)
        # loss for self-contact from the mimicked pose
        if len(presented_contact) > 0:
            # distances between each pair of vertices in contact; the geodesic mask
            # excludes pairs that are close along the mesh surface
            with torch.no_grad():
                cvertstobody = v2v[presented_contact, :]
                cvertstobody = cvertstobody[:, presented_contact]
                maskgeo = self.geomask[presented_contact, :]
                maskgeo = maskgeo[:, presented_contact]
                weights = torch.ones_like(cvertstobody).to(vertices.device)
                weights[~maskgeo] = float("inf")
                # the +1 avoids 0 * inf = nan for zero distances
                min_idx = torch.min((cvertstobody + 1) * weights, 1)[1]
                min_idx = presented_contact[min_idx.cpu().numpy()]
            v2v_min = v2v[presented_contact, min_idx]
            # tanh will not pull vertices that are more than ~contact_thresh apart
            if contact_mode == "dist_tanh":
                contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh)
                contactloss = contactloss.mean()
            else:
                contactloss = v2v_min.mean()
        return contactloss
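
# Illustrative usage sketch for MimickedSelfContactLoss (not part of the original
# pipeline); the vertex count, geodesic mask, and contact indices below are assumed
# toy values. In practice the mask would come from precomputed geodesic distances.
def _self_contact_usage_example():
    nv = 64
    # True where two vertices are far apart on the mesh surface; the diagonal is
    # False so a vertex is never matched with itself
    geomask = ~torch.eye(nv, dtype=torch.bool)
    criterion = MimickedSelfContactLoss(geodesics_mask=geomask)
    vertices = torch.randn(1, nv, 3, requires_grad=True)  # a single toy mesh
    presented_contact = torch.tensor([3, 10, 42])  # vertex indices in contact
    return criterion(presented_contact, vertices)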