import os import torch import argparse from config.config import config_train from lib.dataset.Dataset import MeshDataset from lib.dataset.DataLoaderX import DataLoaderX from lib.module.MeshHeadModule import MeshHeadModule from lib.module.CameraModule import CameraModule from lib.recorder.Recorder import MeshHeadTrainRecorder from lib.trainer.MeshHeadTrainer import MeshHeadTrainer if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--config', type=str, default='config/train_s1_N031.yaml') arg = parser.parse_args() cfg = config_train() cfg.load(arg.config) cfg = cfg.get_cfg() dataset = MeshDataset(cfg.dataset) dataloader = DataLoaderX(dataset, batch_size=cfg.batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=8) device = torch.device('cuda:%d' % cfg.gpu_id) torch.cuda.set_device(cfg.gpu_id) meshhead = MeshHeadModule(cfg.meshheadmodule, dataset.init_landmarks_3d_neutral).to(device) if os.path.exists(cfg.load_meshhead_checkpoint): meshhead.load_state_dict(torch.load(cfg.load_meshhead_checkpoint, map_location=lambda storage, loc: storage)) else: meshhead.pre_train_sphere(300, device) camera = CameraModule() recorder = MeshHeadTrainRecorder(cfg.recorder) optimizer = torch.optim.Adam([{'params' : meshhead.landmarks_3d_neutral, 'lr' : cfg.lr_lmk}, {'params' : meshhead.geo_mlp.parameters(), 'lr' : cfg.lr_net}, {'params' : meshhead.exp_color_mlp.parameters(), 'lr' : cfg.lr_net}, {'params' : meshhead.pose_color_mlp.parameters(), 'lr' : cfg.lr_net}, {'params' : meshhead.exp_deform_mlp.parameters(), 'lr' : cfg.lr_net}, {'params' : meshhead.pose_deform_mlp.parameters(), 'lr' : cfg.lr_net}]) trainer = MeshHeadTrainer(dataloader, meshhead, camera, optimizer, recorder, cfg.gpu_id) trainer.train(0, 50)