import torch
import torch.nn as nn
import torch.nn.functional as F

import spiga.models.gnn.pose_proj as pproj
from spiga.models.cnn.cnn_multitask import MultitaskCNN
from spiga.models.gnn.step_regressor import StepRegressor, RelativePositionEncoder

class SPIGA(nn.Module):

    def __init__(self, num_landmarks=98, num_edges=15, steps=3, **kwargs):
        super(SPIGA, self).__init__()

        # Model parameters
        self.steps = steps  # Cascaded regressors
        self.embedded_dim = 512  # GAT input channel
        self.nstack = 4  # Number of stacked GATs per step
        self.kwindow = 7  # Output cropped window dimension (kernel)
        self.swindow = 0.25  # Scale of the cropped window at the first step (default: 25% of the input feature map)
        self.offset_ratio = [self.swindow / (2 ** step) / 2 for step in range(self.steps)]
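        # Each step regresses per-landmark offsets in [-1, 1] (hardtanh), so
        # offset_ratio[step], half the window scale at that step, keeps every
        # update inside the current crop window and halves the search range
        # from one cascade step to the next.
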
        # CNN parameters
        self.num_landmarks = num_landmarks
        self.num_edges = num_edges

        # Initialize backbone
        self.visual_cnn = MultitaskCNN(num_landmarks=self.num_landmarks, num_edges=self.num_edges)

        # Features dimensions
        self.img_res = self.visual_cnn.img_res
        self.visual_res = self.visual_cnn.out_res
        self.visual_dim = self.visual_cnn.ch_dim

        # Initialize Pose head
        self.channels_pose = 6
        self.pose_fc = nn.Linear(self.visual_cnn.ch_dim, self.channels_pose)

        # Initialize feature extractors:
        # Relative positional encoder
        shape_dim = 2 * (self.num_landmarks - 1)
        shape_encoder = []
        for step in range(self.steps):
            shape_encoder.append(RelativePositionEncoder(shape_dim, self.embedded_dim, [256, 256]))
        self.shape_encoder = nn.ModuleList(shape_encoder)

        # Diagonal mask used to compute relative positions
        diagonal_mask = (torch.ones(self.num_landmarks, self.num_landmarks) - torch.eye(self.num_landmarks)).type(torch.bool)
        self.diagonal_mask = nn.parameter.Parameter(diagonal_mask, requires_grad=False)

        # Visual feature extractor
        conv_window = []
        theta_S = []
        for step in range(self.steps):
            # S matrix per step
            WH = self.visual_res  # Width/height of the feature map
            Wout = self.swindow / (2 ** step) * WH  # Width/height of the window
            K = self.kwindow  # Kernel or resolution of the window
            scale = K / WH * (Wout - 1) / (K - 1)  # Scale of the affine transformation
            # Rescale matrix S
            theta_S_stp = torch.tensor([[scale, 0], [0, scale]])
            theta_S.append(nn.parameter.Parameter(theta_S_stp, requires_grad=False))
            # Convolution that embeds each KxK crop into a (B*L) x embedded_dim x 1 x 1 feature
            conv_window.append(nn.Conv2d(self.visual_dim, self.embedded_dim, self.kwindow))
        self.theta_S = nn.ParameterList(theta_S)
        self.conv_window = nn.ModuleList(conv_window)

        # Initialize GAT modules
        self.gcn = nn.ModuleList([StepRegressor(self.embedded_dim, 256, self.nstack) for i in range(self.steps)])

    def forward(self, data):
        # Backbone forward: visual features and initial point projections
        pts_proj, features = self.backbone_forward(data)

        # Visual field
        visual_field = features['VisualField'][-1]

        # Cascaded regression: the GAT attention (gat_prob) of each step is fed
        # to the next one and the landmark estimate of every step is stored
        gat_prob = []
        features['Landmarks'] = []
        for step in range(self.steps):
            # Features generation
            embedded_ft = self.extract_embedded(pts_proj, visual_field, step)
            # GAT inference
            offset, gat_prob = self.gcn[step](embedded_ft, gat_prob)
            offset = F.hardtanh(offset)
            # Update coordinates
            pts_proj = pts_proj + self.offset_ratio[step] * offset
            features['Landmarks'].append(pts_proj.clone())

        features['GATProb'] = gat_prob
        return features

    def backbone_forward(self, data):
        # Inputs: image, 3D model and camera matrix
        imgs = data[0]
        model3d = data[1]
        cam_matrix = data[2]

        # HourGlass forward
        features = self.visual_cnn(imgs)

        # Head pose estimation
        pose_raw = features['HGcore'][-1]
        B, L, _, _ = pose_raw.shape
        pose = pose_raw.reshape(B, L)
        pose = self.pose_fc(pose)
        features['Pose'] = pose.clone()

        # Project model 3D
        euler = pose[:, 0:3]
        trl = pose[:, 3:]
        rot = pproj.euler_to_rotation_matrix(euler)
        pts_proj = pproj.projectPoints(model3d, rot, trl, cam_matrix)
        pts_proj = pts_proj / self.visual_res
        return pts_proj, features

    def extract_embedded(self, pts_proj, receptive_field, step):
        # Visual features
        visual_ft = self.extract_visual_embedded(pts_proj, receptive_field, step)

        # Shape features
        shape_ft = self.calculate_distances(pts_proj)
        shape_ft = self.shape_encoder[step](shape_ft)

        # Addition
        embedded_ft = visual_ft + shape_ft
        return embedded_ft

    def extract_visual_embedded(self, pts_proj, receptive_field, step):
        # Affine matrix generation
        B, L, _ = pts_proj.shape  # pts_proj range: [0, 1]
        centers = pts_proj + 0.5 / self.visual_res  # BxLx2
        centers = centers.reshape(B * L, 2)  # B*Lx2
        theta_trl = (-1 + centers * 2).unsqueeze(-1)  # B*Lx2x1
        theta_s = self.theta_S[step]  # 2x2
        theta_s = theta_s.repeat(B * L, 1, 1)  # B*Lx2x2
        theta = torch.cat((theta_s, theta_trl), -1)  # B*Lx2x3
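        # The affine matrix combines a fixed per-step scale with a per-landmark
        # translation: affine_grid works in normalized [-1, 1] coordinates, so
        # theta_trl maps each landmark from [0, 1] to its normalized position,
        # while theta_S shrinks the KxK sampling grid so that each crop covers
        # roughly a swindow / 2**step fraction of the feature map, centered on
        # its landmark (i.e. the windows get finer at every cascade step).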
        # Generate crop grid
        B, C, _, _ = receptive_field.shape
        grid = torch.nn.functional.affine_grid(theta, (B * L, C, self.kwindow, self.kwindow))
        grid = grid.reshape(B, L, self.kwindow, self.kwindow, 2)
        grid = grid.reshape(B, L, self.kwindow * self.kwindow, 2)

        # Crop windows
        crops = torch.nn.functional.grid_sample(receptive_field, grid, padding_mode="border")  # BxCxLxK*K
        crops = crops.transpose(1, 2)  # BxLxCxK*K
        crops = crops.reshape(B * L, C, self.kwindow, self.kwindow)

        # Flatten features
        visual_ft = self.conv_window[step](crops)
        _, Cout, _, _ = visual_ft.shape
        visual_ft = visual_ft.reshape(B, L, Cout)
        return visual_ft

    def calculate_distances(self, pts_proj):
        B, L, _ = pts_proj.shape  # BxLx2
        pts_a = pts_proj.unsqueeze(-2).repeat(1, 1, L, 1)
        pts_b = pts_a.transpose(1, 2)
        dist = pts_a - pts_b
        dist_wo_self = dist[:, self.diagonal_mask, :].reshape(B, L, -1)
        return dist_wo_self
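
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module), only to show how the
# forward pass is wired. The input layouts below are assumptions: a 3-channel
# image of size self.img_res, a BxLx3 rigid 3D model and a Bx3x3 camera matrix;
# the exact formats are defined by MultitaskCNN and spiga.models.gnn.pose_proj.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = SPIGA(num_landmarks=98, num_edges=15, steps=3)
    model.eval()

    batch = 2
    imgs = torch.randn(batch, 3, model.img_res, model.img_res)  # Assumed RGB crop
    model3d = torch.randn(batch, model.num_landmarks, 3)        # Assumed BxLx3 rigid model
    cam_matrix = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1)  # Assumed Bx3x3 intrinsics

    with torch.no_grad():
        out = model((imgs, model3d, cam_matrix))

    # One landmark estimate per cascade step (BxLx2, normalized by the
    # feature-map resolution), plus the Bx6 head pose (Euler angles + translation).
    print([pts.shape for pts in out['Landmarks']])
    print(out['Pose'].shape)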