# Uploaded by svjack via huggingface_hub
# (commit 9390e2c: "Upload SPIGA with huggingface_hub", 3.35 kB).
import copy
import json
import os

import pkg_resources
import torch

import spiga.data.loaders.dataloader as dl
import spiga.data.loaders.dl_config as dl_cfg
import spiga.inference.pretreatment as pretreat
from spiga.inference.config import ModelConfig
from spiga.inference.framework import SPIGAFramework

def main():
    """Command-line entry point: evaluate a SPIGA model on one dataset split.

    Reads the dataset name, the annotation split and the GPU id from the
    command line, builds the model framework for that dataset, and lets
    ``Tester.generate_results`` dump the predictions with autograd disabled.
    """
    import argparse

    datasets = ['wflw', '300wpublic', '300wprivate', "merlrav", "cofw68"]
    parser = argparse.ArgumentParser(description='Experiment results generator')
    parser.add_argument('database', type=str, choices=datasets,
                        help='Database name')
    parser.add_argument('-a', '--anns', type=str, default='test',
                        help='Annotations type: test, valid or train')
    parser.add_argument('--gpus', type=int, default=0, help='GPU Id')
    cli = parser.parse_args()

    # The model configuration is built from the same dataset name that
    # selects the evaluation data below.
    framework = SPIGAFramework(ModelConfig(cli.database), gpus=[cli.gpus])

    tester = Tester(framework, cli.database, anns_type=cli.anns)
    with torch.no_grad():
        tester.generate_results()


class Tester:
    """Run a SPIGA model over an evaluation split and dump predictions to JSON.

    Args:
        model_framework: Framework object exposing ``model_cfg`` (with
            ``target_dist``, ``image_size`` and ``ftmap_size``) plus the
            ``select_inputs``, ``net_forward`` and ``postreatment`` methods.
        database: Dataset name forwarded to ``dl_cfg.AlignConfig``.
        anns_type: Annotation split forwarded as ``mode`` (default 'test').
    """

    def __init__(self, model_framework, database, anns_type='test'):
        # Parameters
        self.anns_type = anns_type
        self.database = database

        # Model initialization
        self.model_framework = model_framework

        # Dataloader: no augmentation, fixed order, and geometry taken from
        # the model config so the samples match what the network expects.
        self.dl_eval = dl_cfg.AlignConfig(self.database, mode=self.anns_type)
        self.dl_eval.aug_names = []
        self.dl_eval.shuffle = False
        model_cfg = self.model_framework.model_cfg
        self.dl_eval.target_dist = model_cfg.target_dist
        self.dl_eval.image_size = model_cfg.image_size
        self.dl_eval.ftmap_size = model_cfg.ftmap_size
        # generate_results reads element [0] of every batch, so it relies on
        # a batch size of exactly one.
        self.batch_size = 1
        self.test_data, _ = dl.get_dataloader(
            self.batch_size, self.dl_eval,
            pretreat=pretreat.NormalizeAndPermute(), debug=True)

        # Results: template copied for every sample. Every default is
        # JSON-serializable, so unfilled fields cannot break json.dump.
        self.data_struc = {'imgpath': None, 'bbox': None, 'headpose': None,
                           'ids': None, 'landmarks': None, 'visible': None}
        self.result_path = pkg_resources.resource_filename('spiga', 'eval/results')
        self.result_file = '/results_%s_%s.json' % (self.database, self.anns_type)
        self.file_out = self.result_path + self.result_file

    def generate_results(self):
        """Run inference over ``self.test_data`` and write records to ``self.file_out``.

        Each record is a copy of ``self.data_struc`` filled with the image
        path, raw bbox, visibility flags, landmark ids, predicted landmarks
        and head pose of the first (only) sample of each batch. The list of
        records is written as a single JSON array.
        """
        # Make sure the output folder exists before the (potentially long)
        # inference loop, so the results cannot be lost at write time.
        out_dir = os.path.dirname(self.file_out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        data = []
        # Inference only: disable autograd here instead of relying on every
        # caller to do it, so no gradient history is kept.
        with torch.no_grad():
            for step, batch in enumerate(self.test_data):
                print('Step: ', step)
                inputs = self.model_framework.select_inputs(batch)
                outputs_raw = self.model_framework.net_forward(inputs)
                # Postprocessing
                outputs = self.model_framework.postreatment(
                    outputs_raw, batch['bbox'], batch['bbox_raw'])

                # Data: Tensor.tolist() works regardless of the tensor's device.
                data_dict = copy.deepcopy(self.data_struc)
                data_dict['imgpath'] = batch['imgpath_local'][0]
                data_dict['bbox'] = batch['bbox_raw'][0].tolist()
                data_dict['visible'] = batch['visible'][0].tolist()
                data_dict['ids'] = self.dl_eval.database.ldm_ids
                data_dict['landmarks'] = outputs['landmarks'][0]
                data_dict['headpose'] = outputs['headpose'][0]
                data.append(data_dict)

        # Save outputs
        with open(self.file_out, 'w') as outfile:
            json.dump(data, outfile)


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, never on import.
    main()