File size: 2,338 Bytes
9390e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from collections import OrderedDict

from spiga.data.loaders.dl_config import DatabaseStruct

# Download URL of the pretrained SPIGA weights, keyed by dataset name.
# 'cofw68' deliberately points at the same file as '300wprivate'
# (ModelConfig.update_with_dataset also maps it to spiga_300wprivate.pt).
MODELS_URL = {
    'wflw': 'https://drive.google.com/uc?export=download&confirm=yes&id=1h0qA5ysKorpeDNRXe9oYkVcVe8UYyzP7',
    '300wpublic': 'https://drive.google.com/uc?export=download&confirm=yes&id=1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC',
    '300wprivate': 'https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM',
    'merlrav': 'https://drive.google.com/uc?export=download&confirm=yes&id=1GKS1x0tpsTVivPZUk_yrSiMhwEAcAkg6',
    'cofw68': 'https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM',
}


class ModelConfig:
    """Configuration container for a SPIGA model.

    Stores where the model weights come from, the preprocessing parameters
    and the output feature-map size. If ``dataset_name`` is supplied, the
    dataset-specific fields are filled in by :meth:`update_with_dataset`.

    Args:
        dataset_name: Optional dataset key; when not ``None`` it is passed to
            :meth:`update_with_dataset` during construction.
        load_model_url: When true, :meth:`update_with_dataset` also sets
            ``model_weights_url`` from ``MODELS_URL``.
    """

    def __init__(self, dataset_name=None, load_model_url=True):
        # NOTE: assignment order matters -- state_dict() walks the instance
        # __dict__, so it reports attributes in the order they were first set.
        # Weights: file name, local path and (optional) download URL.
        self.model_weights = self.model_weights_path = None
        self.load_model_url = load_model_url
        self.model_weights_url = None
        # Preprocessing.
        self.focal_ratio = 1.5   # Camera matrix focal length ratio.
        self.target_dist = 1.6   # Target distance zoom in/out around face.
        self.image_size = (256, 256)
        # Outputs.
        self.ftmap_size = (64, 64)
        # Dataset description; filled in by update_with_dataset().
        self.dataset = None

        if dataset_name is None:
            return
        self.update_with_dataset(dataset_name)

    def update_with_dataset(self, dataset_name):
        """Apply the settings associated with ``dataset_name``.

        Sets ``dataset`` to ``DatabaseStruct(dataset_name)`` and
        ``model_weights`` to ``'spiga_<dataset_name>.pt'``. The test-only
        ``'cofw68'`` dataset reuses ``'spiga_300wprivate.pt'`` instead. When
        ``load_model_url`` is true, ``model_weights_url`` is looked up in the
        ``MODELS_URL`` dict, so a name missing there raises ``KeyError``.
        """
        # DatabaseStruct is built first, before any other lookup.
        settings = {'dataset': DatabaseStruct(dataset_name)}
        if dataset_name == 'cofw68':     # Test only: reuse the 300W private weights.
            settings['model_weights'] = 'spiga_300wprivate.pt'
        else:
            settings['model_weights'] = 'spiga_%s.pt' % dataset_name

        if self.load_model_url:
            settings['model_weights_url'] = MODELS_URL[dataset_name]

        self.update(settings)

    def update(self, params_dict):
        """Assign each ``key: value`` pair of ``params_dict`` as an attribute.

        A key is accepted if it appears in :meth:`state_dict` or if
        ``hasattr(self, key)`` is true. The first unknown key raises
        ``Warning`` (as an exception); pairs processed before it have
        already been applied.
        """
        known = self.state_dict()
        for key, value in params_dict.items():
            if key not in known and not hasattr(self, key):
                raise Warning('Unknown option: {}: {}'.format(key, value))
            setattr(self, key, value)

    def state_dict(self):
        """Return an ``OrderedDict`` of the public instance attributes.

        Attribute names starting with ``_`` are skipped; the remaining ones
        keep the instance ``__dict__`` order and are read via ``getattr``.
        """
        return OrderedDict(
            (name, getattr(self, name))
            for name in self.__dict__
            if not name.startswith('_')
        )