kbrodt committed on
Commit
347ea73
1 Parent(s): 00074b2

Upload hmr.py

Browse files
Files changed (1) hide show
  1. src/spin/hmr.py +196 -0
src/spin/hmr.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision.models.resnet as resnet
7
+
8
+
9
def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.

    Based on Zhou et al., "On the Continuity of Rotation Representations in
    Neural Networks", CVPR 2019.

    Input:
        Tensor whose element count is a multiple of 6, e.g. (B,6). It is
        regrouped into N = numel / 6 blocks of shape (3,2); the two columns
        of each block are the vectors that get orthonormalised.
    Output:
        (N,3,3) Batch of corresponding rotation matrices, whose columns are
        the orthonormal basis (b1, b2, b1 x b2).
    """
    # reshape instead of view so that non-contiguous inputs are accepted too.
    x = x.reshape(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]

    # Gram-Schmidt: normalise a1, then remove its component from a2.
    b1 = nn.functional.normalize(a1, dim=-1)
    b2 = nn.functional.normalize(
        a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1, dim=-1
    )

    # dim must be explicit: without it torch.cross uses the first dimension
    # of size 3, which is the batch axis whenever exactly 3 rotations are
    # passed in, and the result is then not a rotation matrix.
    b3 = torch.cross(b1, b2, dim=-1)

    return torch.stack((b1, b2, b3), dim=-1)
29
+
30
+
31
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Adapted from the official PyTorch implementation. The block emits
    ``planes * 4`` channels; ``stride`` is applied by the 3x3 convolution.
    ``downsample``, when given, is applied to the input before it is added
    back as the shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * 4

        # 1x1 convolution: inplanes -> planes
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # 3x3 convolution, the only strided one
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 convolution: planes -> planes * 4
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/bn stages, add the shortcut, then apply ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
73
+
74
+
75
class HMR(nn.Module):
    """SMPL Iterative Regressor with a ResNet-style backbone.

    A convolutional backbone built from ``block`` produces one feature vector
    per image. A small MLP then refines pose, shape and camera estimates over
    several iterations, starting from the mean parameters (or from the
    caller-supplied initial values), each iteration adding a correction to
    the current estimate.

    Args:
        block: residual block class. It must expose an ``expansion`` class
            attribute and accept ``(inplanes, planes, stride, downsample)``.
        layers: sequence of four ints, the number of blocks in each of the
            four backbone stages.
        smpl_mean_params: argument for ``np.load``; the loaded archive must
            provide the ``"pose"``, ``"shape"`` and ``"cam"`` arrays.
    """

    def __init__(self, block, layers, smpl_mean_params):
        super().__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.inplanes = 64
        self.n_shape = 10
        self.n_cam = 3
        self.n_joints = 24
        npose = self.n_joints * 6  # 6D rotation representation per joint

        # Backbone
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pooling. Equivalent to the former AvgPool2d(7) on a
        # 7x7 feature map, but also works for other input resolutions. The
        # layer has no parameters, so checkpoints are unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Iterative regressor: image features + current estimate -> update.
        self.fc1 = nn.Linear(
            512 * block.expansion + npose + self.n_shape + self.n_cam, 1024
        )
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(1024, npose)
        self.decshape = nn.Linear(1024, self.n_shape)
        self.deccam = nn.Linear(1024, self.n_cam)
        # Small gain keeps the initial corrections close to zero.
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        init_pose, init_shape, init_cam = self._load_mean_params(smpl_mean_params)
        self.register_buffer("init_pose", init_pose)
        self.register_buffer("init_shape", init_shape)
        self.register_buffer("init_cam", init_cam)

    @staticmethod
    def _load_mean_params(smpl_mean_params):
        """Load the mean pose/shape/cam as float32 tensors with a leading batch axis.

        The archive is closed as soon as the arrays are read. All three arrays
        are cast to float32 so they can be concatenated with the float32 image
        features regardless of the dtype they were saved with.
        """
        with np.load(smpl_mean_params) as mean_params:
            return tuple(
                torch.from_numpy(mean_params[key][:].astype("float32")).unsqueeze(0)
                for key in ("pose", "shape", "cam")
            )

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one backbone stage of ``blocks`` consecutive ``block`` modules.

        Only the first block receives ``stride`` and, when the shape of the
        shortcut would not match, a 1x1 conv + batch-norm ``downsample``.
        ``self.inplanes`` is updated to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        """Regress rotation matrices, shape and camera from a batch of images.

        Args:
            x: image batch with 3 channels, (B, 3, H, W).
            init_pose: optional starting pose, (B, n_joints * 6); defaults to
                the stored mean pose expanded over the batch.
            init_shape: optional starting shape, (B, n_shape); defaults to the
                stored mean shape.
            init_cam: optional starting camera, (B, n_cam); defaults to the
                stored mean camera.
            n_iter: number of refinement iterations.

        Returns:
            Tuple ``(pred_rotmat, pred_shape, pred_cam)`` where ``pred_rotmat``
            is (B, n_joints, 3, 3) and the other two keep the shapes of their
            initial values.
        """
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        # Backbone
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.flatten(1)

        # Iterative refinement: every pass predicts a residual that is added
        # to the current estimate.
        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for _ in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, self.n_joints, 3, 3)

        return pred_rotmat, pred_shape, pred_cam
185
+
186
+
187
def hmr(smpl_mean_params, pretrained=True, **kwargs):
    """Constructs an HMR model with ResNet50 backbone.

    Args:
        smpl_mean_params: passed through to ``HMR`` unchanged.
        pretrained (bool): If True, the weights of torchvision's ImageNet
            ResNet50 are copied into the model.
        **kwargs: extra keyword arguments passed through to ``HMR``.

    Returns:
        The constructed ``HMR`` instance.
    """
    net = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)
    if not pretrained:
        return net

    # strict=False: the two state dicts do not have identical key sets, so
    # only the entries they share are copied.
    # NOTE(review): newer torchvision releases prefer a `weights=` argument
    # over `pretrained=True` - confirm against the pinned version.
    imagenet_weights = resnet.resnet50(pretrained=True).state_dict()
    net.load_state_dict(imagenet_weights, strict=False)
    return net