File size: 5,653 Bytes
9390e2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from spiga.inference.config import ModelConfig
from spiga.models.spiga import SPIGA
import spiga.inference.pretreatment as pretreat
import os
import pkg_resources
import copy
import torch
import numpy as np
# Paths
# Fallback directory for the weight files bundled inside the installed `spiga`
# package; SPIGAFramework uses it when the config has no `model_weights_path`.
weights_path_dft = pkg_resources.resource_filename(
    'spiga',
    'models/weights',
)
class SPIGAFramework:
# Inference wrapper around the SPIGA network: builds the model from a
# ModelConfig, loads its weights, optionally loads the 3D model / camera
# matrix, and exposes inference(image, bboxes).
def __init__(self, model_cfg: ModelConfig, gpus=[0], load3DM=True):
# Build transforms and the SPIGA model, load its weights, and (when load3DM
# is True) load the 3D model / camera intrinsic matrix used to batch inputs.
#
# model_cfg: ModelConfig; read here for dataset settings (num_landmarks,
#     ldm_ids) and weight-location settings (model_weights_path,
#     load_model_url, model_weights_url, model_weights).
# gpus: device ids; only gpus[0] is used, and only when CUDA is available.
#     NOTE(review): mutable list default is shared across calls; it is not
#     mutated here, but a tuple/None default would be safer.
# load3DM: when False, self.model3d / self.cam_matrix are never set, so
#     pretreat() (which reads them) cannot be used.
# Parameters
self.model_cfg = model_cfg
self.gpus = gpus
# Pretreatment initialization
self.transforms = pretreat.get_transformers(self.model_cfg)
# SPIGA model
# Input names in the order select_inputs() gathers them from a batch.
self.model_inputs = ['image', "model3d", "cam_matrix"]
# NOTE(review): the remaining arguments / closing parenthesis of this call
# are not visible in this view — confirm against the full file.
self.model = SPIGA(num_landmarks=model_cfg.dataset.num_landmarks,
# Load weights and set model
# Fall back to the weights directory bundled with the installed package.
weights_path = self.model_cfg.model_weights_path
if weights_path is None:
weights_path = weights_path_dft
if self.model_cfg.load_model_url:
# NOTE(review): the rest of this call is not visible here, and as shown the
# local-file branch below is not guarded by an `else:` — confirm.
model_state_dict = torch.hub.load_state_dict_from_url(self.model_cfg.model_weights_url,
weights_file = os.path.join(
weights_path, self.model_cfg.model_weights)
# NOTE(review): no map_location — a checkpoint saved with CUDA tensors fails
# to load on a CPU-only machine even though the .cuda() move below is
# guarded; consider torch.load(weights_file, map_location='cpu').
model_state_dict = torch.load(weights_file)
# NOTE(review): no load_state_dict()/eval() call is visible in this view;
# confirm model_state_dict is actually applied to self.model.
# self.model = self.model.cuda(gpus[0])
# Move to the first listed GPU only when CUDA is present; otherwise stay on CPU.
self.model = self.model.cuda(
gpus[0]) if torch.cuda.is_available() else self.model
print('SPIGA model loaded!')
# Load 3D model and camera intrinsic matrix
if load3DM:
# NOTE(review): remaining AddModel3D arguments are not visible in this view.
loader_3DM = pretreat.AddModel3D(model_cfg.dataset.ldm_ids,
params_3DM = self._data2device(loader_3DM())
self.model3d = params_3DM['model3d']
self.cam_matrix = params_3DM['cam_matrix']
def inference(self, image, bboxes):
    """Run SPIGA on every face bounding box of an image.

    @param self:
    @param image: Raw image
    @param bboxes: List of bounding boxes found on the image [[x,y,w,h],...]
    @return: features dict {'landmarks': list with shape (num_bbox, num_landmarks, 2) and x,y referred to image size,
                            'headpose': list with shape (num_bbox, 6) euler->[:3], trl->[3:]}
             Each key is present only if the network output provides it.
    """
    # Crop every face, run the network once on the whole batch, then map the
    # predictions back to original-image coordinates.
    batch_crops, crop_bboxes = self.pretreat(image, bboxes)
    outputs = self.net_forward(batch_crops)
    features = self.postreatment(outputs, crop_bboxes, bboxes)
    return features
def pretreat(self, image, bboxes):
    """Crop every bbox out of `image` and assemble the batched SPIGA inputs.

    Args:
        image: Raw image; deep-copied per bbox so the transform pipeline never
            mutates the caller's object.
        bboxes: One box per face ([x, y, w, h] per the `inference` docstring).

    Returns:
        (model_inputs, crop_bboxes):
            model_inputs -- [batch_images, batch_model3D, batch_cam_matrix],
                each with one entry per bbox along dim 0, in the same order as
                `self.model_inputs`.
            crop_bboxes -- the bbox reported by each transformed sample,
                consumed by `postreatment`.
    """
    crop_bboxes = []
    crop_images = []
    for bbox in bboxes:
        sample = {'image': copy.deepcopy(image),
                  'bbox': copy.deepcopy(bbox)}
        sample_crop = self.transforms(sample)
        # Collect the crop and its bbox; without these the image batch would be
        # empty and postreatment would receive no crop boxes.
        # Assumes the transforms return the sample dict with updated
        # 'image'/'bbox' entries — confirm against pretreat.get_transformers.
        crop_bboxes.append(sample_crop['bbox'])
        crop_images.append(sample_crop['image'])
    # Images to tensor and device
    batch_images = torch.tensor(np.array(crop_images), dtype=torch.float)
    batch_images = self._data2device(batch_images)
    # Batch 3D model and camera intrinsic matrix: the same tensor repeated per bbox.
    num_faces = len(bboxes)
    batch_model3D = self.model3d.unsqueeze(0).repeat(num_faces, 1, 1)
    batch_cam_matrix = self.cam_matrix.unsqueeze(0).repeat(num_faces, 1, 1)
    # SPIGA inputs
    model_inputs = [batch_images, batch_model3D, batch_cam_matrix]
    return model_inputs, crop_bboxes
def net_forward(self, inputs):
    """Feed the prepared input list straight into the SPIGA network and hand back its raw output."""
    return self.model(inputs)
def postreatment(self, output, crop_bboxes, bboxes):
    """Convert raw network outputs into per-face features in image coordinates.

    Args:
        output: Network output mapping. When present, 'Landmarks' (its last
            element is used) and 'Pose' are read.
        crop_bboxes: Per-face bbox reported by the crop transforms; columns
            0:2 are the x,y offset and 2:4 the w,h size.
        bboxes: Per-face bbox in the original image, same column layout.

    Returns:
        dict with 'landmarks' (num_bbox, num_landmarks, 2) and/or 'headpose'
        as nested Python lists; a key is absent if its output is absent.
    """
    result = {}
    crop_boxes = np.array(crop_bboxes)
    img_boxes = np.array(bboxes)

    if 'Landmarks' in output:
        lnd = output['Landmarks'][-1].cpu().detach().numpy()
        # Scale by the model input size (predictions appear to be normalised).
        lnd = lnd * self.model_cfg.image_size
        # [:, None, :] lets each face's box broadcast over all of its landmarks:
        # express points relative to the crop box, then re-project them into
        # the matching original bbox.
        rel = (lnd - crop_boxes[:, None, 0:2]) / crop_boxes[:, None, 2:4]
        projected = rel * img_boxes[:, None, 2:4] + img_boxes[:, None, 0:2]
        result['landmarks'] = projected.tolist()

    # Pose output
    if 'Pose' in output:
        result['headpose'] = output['Pose'].cpu().detach().numpy().tolist()

    return result
def select_inputs(self, batch):
    """Gather the model inputs from a batch dict, in `self.model_inputs` order.

    Each selected tensor is cast to float (matching the `dtype=torch.float`
    images built in `pretreat`) and moved to the device via `_data2device`.
    Previously the loop fetched each entry but never collected it, so this
    always returned an empty list.

    Args:
        batch: mapping containing every key listed in `self.model_inputs`;
            values are expected to be torch tensors.

    Returns:
        list with one tensor per name in `self.model_inputs`.

    Raises:
        KeyError: if `batch` lacks one of the required keys.
    """
    return [self._data2device(batch[ft_name].type(torch.float))
            for ft_name in self.model_inputs]
def _data2device(self, data):
if isinstance(data, list):
data_var = data
for data_id, v_data in enumerate(data):
data_var[data_id] = self._data2device(v_data)
if isinstance(data, dict):
data_var = data
for k, v in data.items():
data[k] = self._data2device(v)
with torch.no_grad():
if torch.cuda.is_available():
data_var = data.cuda(
device=self.gpus[0], non_blocking=True)
data_var = data
return data_var