import copy
import os

import numpy as np
import pkg_resources
import torch

from spiga.inference.config import ModelConfig
from spiga.models.spiga import SPIGA
import spiga.inference.pretreatment as pretreat

# Paths
weights_path_dft = pkg_resources.resource_filename('spiga', 'models/weights')


class SPIGAFramework:

    def __init__(self, model_cfg: ModelConfig, gpus=[0], load3DM=True):

        # Parameters
        self.model_cfg = model_cfg
        self.gpus = gpus

        # Pretreatment initialization
        self.transforms = pretreat.get_transformers(self.model_cfg)

        # SPIGA model
        self.model_inputs = ['image', 'model3d', 'cam_matrix']
        self.model = SPIGA(num_landmarks=model_cfg.dataset.num_landmarks,
                           num_edges=model_cfg.dataset.num_edges)

        # Load weights and set model
        weights_path = self.model_cfg.model_weights_path
        if weights_path is None:
            weights_path = weights_path_dft

        if self.model_cfg.load_model_url:
            model_state_dict = torch.hub.load_state_dict_from_url(self.model_cfg.model_weights_url,
                                                                  model_dir=weights_path,
                                                                  file_name=self.model_cfg.model_weights)
        else:
            weights_file = os.path.join(weights_path, self.model_cfg.model_weights)
            # map_location keeps the load CPU-safe; the model is moved to GPU below when available
            model_state_dict = torch.load(weights_file, map_location='cpu')

        self.model.load_state_dict(model_state_dict)
        if torch.cuda.is_available():
            self.model = self.model.cuda(gpus[0])
        self.model.eval()
        print('SPIGA model loaded!')

        # Load 3D model and camera intrinsic matrix
        if load3DM:
            loader_3DM = pretreat.AddModel3D(model_cfg.dataset.ldm_ids,
                                             ftmap_size=model_cfg.ftmap_size,
                                             focal_ratio=model_cfg.focal_ratio,
                                             totensor=True)
            params_3DM = self._data2device(loader_3DM())
            self.model3d = params_3DM['model3d']
            self.cam_matrix = params_3DM['cam_matrix']

    def inference(self, image, bboxes):
        """
        @param self:
        @param image: Raw image
        @param bboxes: List of bounding box founded on the image [[x,y,w,h],...]
        @return: features dict {'landmarks': list with shape (num_bbox, num_landmarks, 2) and x,y referred to image size
                                'headpose': list with shape (num_bbox, 6) euler->[:3], trl->[3:]
        """
        batch_crops, crop_bboxes = self.pretreat(image, bboxes)
        outputs = self.net_forward(batch_crops)
        features = self.postreatment(outputs, crop_bboxes, bboxes)
        return features

    def pretreat(self, image, bboxes):
        """Crop and normalize each bounding box region and build the network input batch."""
        crop_bboxes = []
        crop_images = []
        for bbox in bboxes:
            sample = {'image': copy.deepcopy(image),
                      'bbox': copy.deepcopy(bbox)}
            sample_crop = self.transforms(sample)
            crop_bboxes.append(sample_crop['bbox'])
            crop_images.append(sample_crop['image'])

        # Images to tensor and device
        batch_images = torch.tensor(np.array(crop_images), dtype=torch.float)
        batch_images = self._data2device(batch_images)
        # Batch 3D model and camera intrinsic matrix
        batch_model3D = self.model3d.unsqueeze(0).repeat(len(bboxes), 1, 1)
        batch_cam_matrix = self.cam_matrix.unsqueeze(0).repeat(len(bboxes), 1, 1)

        # SPIGA inputs
        model_inputs = [batch_images, batch_model3D, batch_cam_matrix]
        return model_inputs, crop_bboxes

    def net_forward(self, inputs):
        """Forward the batched inputs through the SPIGA network."""
        outputs = self.model(inputs)
        return outputs

    def postreatment(self, output, crop_bboxes, bboxes):
        """Map network outputs from crop coordinates back to the original image frame."""
        features = {}
        crop_bboxes = np.array(crop_bboxes)
        bboxes = np.array(bboxes)

        # Landmarks output: rescale from crop space back to original image coordinates
        if 'Landmarks' in output:
            landmarks = output['Landmarks'][-1].cpu().detach().numpy()
            landmarks = landmarks.transpose((1, 0, 2))
            landmarks = landmarks * self.model_cfg.image_size
            landmarks_norm = (landmarks - crop_bboxes[:, 0:2]) / crop_bboxes[:, 2:4]
            landmarks_out = (landmarks_norm * bboxes[:, 2:4]) + bboxes[:, 0:2]
            landmarks_out = landmarks_out.transpose((1, 0, 2))
            features['landmarks'] = landmarks_out.tolist()

        # Pose output
        if 'Pose' in output:
            pose = output['Pose'].cpu().detach().numpy()
            features['headpose'] = pose.tolist()

        return features

    def select_inputs(self, batch):
        """Pick the configured model inputs from a batch dict and move them to the device."""
        inputs = []
        for ft_name in self.model_inputs:
            data = batch[ft_name]
            inputs.append(self._data2device(data.type(torch.float)))
        return inputs

    def _data2device(self, data):
        """Recursively move tensors (or lists/dicts of tensors) to the configured device."""
        if isinstance(data, list):
            data_var = data
            for data_id, v_data in enumerate(data):
                data_var[data_id] = self._data2device(v_data)
        elif isinstance(data, dict):
            data_var = data
            for k, v in data.items():
                data_var[k] = self._data2device(v)
        else:
            with torch.no_grad():
                if torch.cuda.is_available():
                    data_var = data.cuda(device=self.gpus[0], non_blocking=True)
                else:
                    data_var = data
        return data_var
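

# Illustrative usage sketch (not part of the framework itself): the dataset name
# 'wflw', the image path and the bounding box values below are assumptions chosen
# only to show the call sequence SPIGAFramework -> inference -> landmarks/headpose.
if __name__ == '__main__':
    import cv2  # assumed available; any loader that yields an ndarray image works

    processor = SPIGAFramework(ModelConfig('wflw'))

    image = cv2.imread('face.jpg')      # hypothetical input image
    bboxes = [[120, 80, 220, 220]]      # hypothetical face box in [x, y, w, h] format

    features = processor.inference(image, bboxes)
    print('landmarks per face:', len(features['landmarks'][0]))
    print('headpose (euler angles + translation):', features['headpose'][0])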