File size: 4,887 Bytes
c3d3e4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb140f6
 
 
 
 
 
 
 
c3d3e4a
 
 
 
 
 
 
 
 
 
 
fb140f6
 
c3d3e4a
 
 
 
 
 
fb140f6
c3d3e4a
 
 
 
 
fb140f6
 
c3d3e4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb140f6
 
 
 
c3d3e4a
fb140f6
 
 
 
 
c3d3e4a
 
 
 
 
 
 
 
 
 
 
fb140f6
c3d3e4a
 
 
 
 
fb140f6
c3d3e4a
 
 
 
 
 
fb140f6
 
 
c3d3e4a
 
 
 
fb140f6
c3d3e4a
 
 
 
 
 
 
fb140f6
c3d3e4a
 
 
 
 
 
 
 
 
 
 
fb140f6
 
c3d3e4a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology.
# All rights reserved.
# This file is part of the pytorch-nicp,
# and is released under the "MIT License Agreement". Please see the LICENSE
# file that should have been included as part of this package.

import torch
import trimesh
import torch.nn as nn
from tqdm import tqdm
from pytorch3d.structures import Meshes
from pytorch3d.loss import chamfer_distance
from lib.dataset.mesh_util import update_mesh_shape_prior_losses
from lib.common.train_util import init_loss


# reference: https://github.com/wuhaozhe/pytorch-nicp
class LocalAffine(nn.Module):
    def __init__(self, num_points, batch_size=1, edges=None):
        '''
            Per-point affine deformation: every point owns a 3x3 matrix A and a
            3x1 translation b, so a point x is mapped to A @ x + b.

            num_points: number of points; must be the same for every batch element
            batch_size: number of point sets in the batch (constant)
            edges:      index pairs of shape (E, 2) (torch.LongTensor or anything
                        torch.as_tensor accepts), or None. stiffness() and
                        therefore forward() require edges to be set.

            A is initialised to the identity and b to zero, so an untrained
            model is the identity transform.
        '''
        super(LocalAffine, self).__init__()
        # A: (B, N, 3, 3)
        self.A = nn.Parameter(
            torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1)
        )
        # b: (B, N, 3, 1)
        self.b = nn.Parameter(
            torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(
                batch_size, num_points, 1, 1
            )
        )
        # Register edges as a buffer so module.to(device) moves them together
        # with A and b (a plain attribute would be left on its old device and
        # make index_select fail). persistent=False keeps them out of
        # state_dict, exactly as before.
        if edges is not None:
            edges = torch.as_tensor(edges, dtype=torch.long)
        self.register_buffer("edges", edges, persistent=False)
        self.num_points = num_points

    def stiffness(self):
        '''
            Compute the regularisation terms of the local affine transformation.

            Returns:
                w_diff:  (B, E, 3, 4) element-wise squared difference between the
                         [A | b] matrices of the two endpoints of every edge.
                         The squared difference is used instead of a Frobenius
                         norm because the norm's gradient is unbounded when the
                         difference is the zero matrix.
                w_rigid: (B, N) squared deviation of det(A) from 1.

            Raises:
                ValueError: if the model was built without edges.
        '''
        if self.edges is None:
            raise ValueError("edges cannot be none when calculate stiff")
        affine_weight = torch.cat((self.A, self.b), dim=3)
        w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0])
        w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1])
        w_diff = (w1 - w2)**2
        w_rigid = (torch.linalg.det(self.A) - 1.0)**2
        return w_diff, w_rigid

    def forward(self, x):
        '''
            Apply the per-point affine transforms to x.

            x: (B, N, 3) points; a trailing axis is added internally so each
               point is treated as a 3x1 column vector.

            Returns (out_x, stiffness, rigid): the transformed (B, N, 3) points
            and the two terms returned by stiffness().
        '''
        x = x.unsqueeze(3)
        out_x = torch.matmul(self.A, x) + self.b
        # Out-of-place squeeze: avoids mutating a tensor that is part of the
        # autograd graph.
        out_x = out_x.squeeze(3)
        stiffness, rigid = self.stiffness()

        return out_x, stiffness, rigid


def trimesh2meshes(mesh):
    '''
        Wrap a trimesh mesh as a single-element pytorch3d Meshes batch.

        Vertices are cast to float32 and faces to int64, and each gets a
        leading batch axis of size 1.
    '''
    vertex_batch = torch.from_numpy(mesh.vertices).float()[None]
    face_batch = torch.from_numpy(mesh.faces).long()[None]
    return Meshes(vertex_batch, face_batch)


def _weighted_loss_sum(losses, device):
    '''
        Sum every active loss term, weighted, and build the progress-bar text.

        A term is active when its weight is > 0 and its value is non-zero.
        Returns (total_loss, description_string).
    '''
    # Start from a grad-requiring leaf so backward() still succeeds even when
    # no loss term is active.
    total = torch.zeros((), device=device, requires_grad=True)
    desc = "Register SMPL-X -> d-BiNI -- "
    for name, term in losses.items():
        if term["weight"] > 0.0 and term["value"] != 0.0:
            weighted = term["value"] * term["weight"]
            total = total + weighted
            desc += f"{name}:{weighted:.3f} | "
    desc += f"TOTAL: {total:.3f}"
    return total, desc


def register(target_mesh, src_mesh, device, num_iters=100, lr=1e-2):
    '''
        Deform src_mesh towards target_mesh with a per-vertex LocalAffine model
        optimised by Adam (chamfer distance + stiffness/rigidity + shape priors).

        target_mesh: trimesh mesh, converted with trimesh2meshes and moved to device
        src_mesh:    pytorch3d Meshes to deform; its vertices and edges feed the
                     model (presumably already on `device` -- TODO confirm)
        device:      torch device for the target mesh and the affine model
        num_iters:   number of optimisation steps (default 100, as before)
        lr:          initial Adam learning rate (default 1e-2, as before)

        Returns a trimesh.Trimesh with the deformed vertices and the source faces.
    '''
    # define local_affine deform verts
    tgt_mesh = trimesh2meshes(target_mesh).to(device)
    src_verts = src_mesh.verts_padded().clone()
    batch_size, num_points = src_verts.shape[:2]

    # NOTE(review): edges_packed() indexes the packed vertex list, while
    # LocalAffine indexes per batch element; consistent only for batch_size == 1.
    local_affine_model = LocalAffine(
        num_points, batch_size, src_mesh.edges_packed()
    ).to(device)

    optimizer_cloth = torch.optim.Adam(
        [{
            'params': local_affine_model.parameters()
        }], lr=lr, amsgrad=True
    )
    # `verbose` is omitted: it defaulted to off and is deprecated in recent
    # PyTorch releases.
    scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_cloth,
        mode="min",
        factor=0.1,
        min_lr=1e-5,
        patience=5,
    )

    losses = init_loss()

    loop_cloth = tqdm(range(num_iters))

    for _ in loop_cloth:

        optimizer_cloth.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(x=src_verts)
        src_mesh = src_mesh.update_padded(deformed_verts)

        # losses for laplacian, edge, normal consistency (filled in by the
        # project helper; presumably writes into losses[...]["value"])
        update_mesh_shape_prior_losses(src_mesh, losses)

        losses["cloth"]["value"] = chamfer_distance(
            x=src_mesh.verts_padded(), y=tgt_mesh.verts_padded()
        )[0]
        losses["stiff"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        # Weighted sum of the losses
        cloth_loss, pbar_desc = _weighted_loss_sum(losses, device)
        loop_cloth.set_description(pbar_desc)

        # update params
        cloth_loss.backward(retain_graph=True)
        optimizer_cloth.step()
        scheduler_cloth.step(cloth_loss.item())

    # Re-evaluate once more so the returned mesh reflects the parameters after
    # the final optimizer step (otherwise that last update would be discarded).
    with torch.no_grad():
        final_verts, _, _ = local_affine_model(x=src_verts)
        src_mesh = src_mesh.update_padded(final_verts)

    final = trimesh.Trimesh(
        src_mesh.verts_packed().detach().squeeze(0).cpu().numpy(),
        src_mesh.faces_packed().detach().squeeze(0).cpu().numpy(),
        process=False,
        maintains_order=True
    )

    return final