Spaces:
Running
Running
import copy
import json
import os

import pkg_resources
import torch

import spiga.data.loaders.dl_config as dl_cfg
import spiga.data.loaders.dataloader as dl
import spiga.inference.pretreatment as pretreat
from spiga.inference.framework import SPIGAFramework
from spiga.inference.config import ModelConfig
def main():
    """Command-line entry point: run a model over a database and store its results.

    Reads the database name, the annotation split and the GPU id from the
    command line, builds the model framework for that database and hands it
    to a ``Tester``, whose ``generate_results`` is executed with gradient
    tracking disabled.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Experiment results generator')
    parser.add_argument(
        'database',
        type=str,
        choices=['wflw', '300wpublic', '300wprivate', "merlrav", "cofw68"],
        help='Database name',
    )
    parser.add_argument(
        '-a', '--anns',
        type=str,
        default='test',
        help='Annotations type: test, valid or train',
    )
    parser.add_argument('--gpus', type=int, default=0, help='GPU Id')
    cli = parser.parse_args()

    # Build the model framework; the single GPU id is wrapped in a list.
    framework = SPIGAFramework(ModelConfig(cli.database), gpus=[cli.gpus])

    # Produce the results file without tracking gradients.
    evaluator = Tester(framework, cli.database, anns_type=cli.anns)
    with torch.no_grad():
        evaluator.generate_results()
class Tester:
    """Run a model framework over a dataset split and save its predictions as JSON.

    The constructor configures an evaluation dataloader from the model
    configuration; ``generate_results`` iterates it, runs the network on each
    batch and writes one record per batch to ``self.file_out``.
    """

    def __init__(self, model_framework, database, anns_type='test'):
        """Configure the evaluation dataloader and the output location.

        Args:
            model_framework: Object exposing ``model_cfg``, ``select_inputs``,
                ``net_forward`` and ``postreatment`` (as used below and in
                ``generate_results``).
            database: Database name, forwarded to ``dl_cfg.AlignConfig`` and
                used in the results file name.
            anns_type: Annotation split, passed as ``mode`` to the dataloader
                configuration and used in the results file name.
        """
        # Parameters
        self.anns_type = anns_type
        self.database = database

        # Model initialization
        self.model_framework = model_framework

        # Dataloader: no augmentations, no shuffling, and the geometry
        # settings copied from the model configuration.
        model_cfg = self.model_framework.model_cfg
        self.dl_eval = dl_cfg.AlignConfig(self.database, mode=self.anns_type)
        self.dl_eval.aug_names = []
        self.dl_eval.shuffle = False
        self.dl_eval.target_dist = model_cfg.target_dist
        self.dl_eval.image_size = model_cfg.image_size
        self.dl_eval.ftmap_size = model_cfg.ftmap_size

        # generate_results() only reads element 0 of every batch, so the
        # batch size must stay at 1.
        self.batch_size = 1
        # NOTE(review): debug=True presumably makes the loader expose the
        # extra 'imgpath_local' / 'bbox_raw' fields read below - confirm.
        self.test_data, _ = dl.get_dataloader(self.batch_size, self.dl_eval,
                                              pretreat=pretreat.NormalizeAndPermute(), debug=True)

        # Results: template of one output record; every field is overwritten
        # per sample in generate_results().
        self.data_struc = {'imgpath': str, 'bbox': None, 'headpose': None, 'ids': None, 'landmarks': None, 'visible': None}
        self.result_path = pkg_resources.resource_filename('spiga', 'eval/results')
        self.result_file = '/results_%s_%s.json' % (self.database, self.anns_type)
        self.file_out = self.result_path + self.result_file

    def generate_results(self):
        """Run the model on every batch and dump the collected records to JSON.

        For each batch the first sample's image path, raw bounding box,
        visibility flags, landmark ids, predicted landmarks and predicted
        head pose are stored in a dict; the list of dicts is written to
        ``self.file_out``.

        Raises:
            OSError: If the output directory cannot be created or the output
                file cannot be written.
        """
        # Make sure the destination directory exists *before* running
        # inference, so a missing/unwritable folder fails fast instead of
        # after the whole dataset has been processed.
        out_dir = os.path.dirname(self.file_out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        data = []
        for step, batch in enumerate(self.test_data):
            print('Step: ', step)
            inputs = self.model_framework.select_inputs(batch)
            outputs_raw = self.model_framework.net_forward(inputs)

            # Postprocessing
            outputs = self.model_framework.postreatment(outputs_raw, batch['bbox'], batch['bbox_raw'])

            # Data: start from a fresh copy of the template and fill every
            # field from the first (and only) sample of the batch.
            data_dict = copy.deepcopy(self.data_struc)
            data_dict['imgpath'] = batch['imgpath_local'][0]
            data_dict['bbox'] = batch['bbox_raw'][0].numpy().tolist()
            data_dict['visible'] = batch['visible'][0].numpy().tolist()
            data_dict['ids'] = self.dl_eval.database.ldm_ids
            data_dict['landmarks'] = outputs['landmarks'][0]
            data_dict['headpose'] = outputs['headpose'][0]
            data.append(data_dict)

        # Save outputs
        with open(self.file_out, 'w') as outfile:
            json.dump(data, outfile)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()