svjack's picture
Upload SPIGA with huggingface_hub
9390e2c
raw
history blame
2.34 kB
from collections import OrderedDict
from spiga.data.loaders.dl_config import DatabaseStruct
# Download links (Google Drive) for the pretrained weights, keyed by dataset name.
MODELS_URL = {
    'wflw': 'https://drive.google.com/uc?export=download&confirm=yes&id=1h0qA5ysKorpeDNRXe9oYkVcVe8UYyzP7',
    '300wpublic': 'https://drive.google.com/uc?export=download&confirm=yes&id=1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC',
    '300wprivate': 'https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM',
    'merlrav': 'https://drive.google.com/uc?export=download&confirm=yes&id=1GKS1x0tpsTVivPZUk_yrSiMhwEAcAkg6',
}
# 'cofw68' points at exactly the same download link as '300wprivate'.
MODELS_URL['cofw68'] = MODELS_URL['300wprivate']
class ModelConfig(object):
    """Plain attribute container holding the configuration of a SPIGA model.

    Every public (non-underscore) instance attribute is a configuration
    option. Options can be changed in bulk through :meth:`update` and
    exported through :meth:`state_dict`.
    """

    def __init__(self, dataset_name=None, load_model_url=True):
        """Create a configuration with default values.

        Args:
            dataset_name: Optional dataset identifier. When given, the
                dataset-specific options are filled in through
                :meth:`update_with_dataset`.
            load_model_url: When true, ``update_with_dataset`` also stores
                the download link of the weights in ``model_weights_url``.
        """
        # Model configuration
        self.model_weights = None
        self.model_weights_path = None
        self.load_model_url = load_model_url
        self.model_weights_url = None
        # Pretreatment
        self.focal_ratio = 1.5  # Camera matrix focal length ratio.
        self.target_dist = 1.6  # Target distance zoom in/out around face.
        self.image_size = (256, 256)
        # Outputs
        self.ftmap_size = (64, 64)
        # Dataset
        self.dataset = None
        if dataset_name is not None:
            self.update_with_dataset(dataset_name)

    def update_with_dataset(self, dataset_name):
        """Fill in the options that depend on the chosen dataset.

        Sets ``dataset`` to ``DatabaseStruct(dataset_name)`` and
        ``model_weights`` to ``'spiga_<dataset_name>.pt'``. If
        ``load_model_url`` is true, ``model_weights_url`` is looked up in
        ``MODELS_URL``.

        Args:
            dataset_name: Dataset identifier.

        Raises:
            KeyError: If ``load_model_url`` is true and ``dataset_name`` has
                no entry in ``MODELS_URL``.
        """
        config_dict = {'dataset': DatabaseStruct(dataset_name),
                       'model_weights': 'spiga_%s.pt' % dataset_name}
        if dataset_name == 'cofw68':  # Test only
            # cofw68 reuses the weights file of 300wprivate.
            config_dict['model_weights'] = 'spiga_300wprivate.pt'
        if self.load_model_url:
            config_dict['model_weights_url'] = MODELS_URL[dataset_name]
        self.update(config_dict)

    def update(self, params_dict):
        """Overwrite existing options with the values in ``params_dict``.

        All keys are validated before anything is assigned, so a rejected
        call leaves the configuration untouched.

        Args:
            params_dict: Mapping of option name to new value. Each key must
                already exist as an attribute of this object.

        Raises:
            Warning: If a key does not name an existing option. ``Warning``
                is raised as an exception (not emitted through the
                ``warnings`` module); the type is kept so existing
                ``except Warning`` handlers continue to work.
        """
        known_options = self.state_dict()
        # First pass: validate only, so a bad key cannot leave the object
        # half-updated depending on dict iteration order.
        for k, v in params_dict.items():
            if k not in known_options and not hasattr(self, k):
                raise Warning('Unknown option: {}: {}'.format(k, v))
        # Second pass: every key is known, apply them all.
        for k, v in params_dict.items():
            setattr(self, k, v)

    def state_dict(self):
        """Return an ``OrderedDict`` of all public instance attributes.

        Attributes whose name starts with an underscore are skipped. The
        order follows the order of the instance ``__dict__``.
        """
        return OrderedDict(
            (k, getattr(self, k))
            for k in self.__dict__
            if not k.startswith('_')
        )