File size: 2,084 Bytes
ec9a6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import torch
import argparse

from config.config import config_train

from lib.dataset.Dataset import MeshDataset
from lib.dataset.DataLoaderX import DataLoaderX
from lib.module.MeshHeadModule import MeshHeadModule
from lib.module.CameraModule import CameraModule
from lib.recorder.Recorder import MeshHeadTrainRecorder
from lib.trainer.MeshHeadTrainer import MeshHeadTrainer

def main():
    """Run MeshHead training from a YAML config.

    Steps:
      1. Parse CLI options and load the training config.
      2. Build the mesh dataset and its loader.
      3. Select the CUDA device given by ``cfg.gpu_id``.
      4. Build the MeshHeadModule. Its state is restored from
         ``cfg.load_meshhead_checkpoint`` when that path is set and exists.
         Otherwise it is initialised via ``pre_train_sphere``.
      5. Build the camera, recorder and Adam optimizer.
      6. Hand everything to ``MeshHeadTrainer.train``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/train_s1_N031.yaml')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='number of DataLoader worker processes')
    args = parser.parse_args()

    # config_train().load() reads the YAML file; get_cfg() returns the
    # resolved config object whose attributes are used below.
    cfg = config_train()
    cfg.load(args.config)
    cfg = cfg.get_cfg()

    dataset = MeshDataset(cfg.dataset)
    dataloader = DataLoaderX(dataset, batch_size=cfg.batch_size, shuffle=True,
                             pin_memory=True, drop_last=True,
                             num_workers=args.num_workers)

    device = torch.device(f'cuda:{cfg.gpu_id}')
    torch.cuda.set_device(cfg.gpu_id)

    meshhead = MeshHeadModule(cfg.meshheadmodule, dataset.init_landmarks_3d_neutral).to(device)
    checkpoint_path = cfg.load_meshhead_checkpoint
    # Treat an unset (None / empty) path as "no checkpoint".
    # os.path.exists(None) raises TypeError rather than returning False.
    if checkpoint_path and os.path.exists(checkpoint_path):
        # Deserialize onto CPU; load_state_dict then copies the tensors into
        # the module's parameters, which already live on `device`.
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from trusted sources. On torch >= 1.13 consider
        # passing weights_only=True.
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        meshhead.load_state_dict(state_dict)
    else:
        # NOTE(review): 300 is passed straight through; presumably the number
        # of sphere pre-training iterations. TODO confirm.
        meshhead.pre_train_sphere(300, device)

    camera = CameraModule()
    recorder = MeshHeadTrainRecorder(cfg.recorder)

    # The landmark tensor gets its own learning rate.
    # Every MLP sub-network shares cfg.lr_net.
    net_modules = (
        meshhead.geo_mlp,
        meshhead.exp_color_mlp,
        meshhead.pose_color_mlp,
        meshhead.exp_deform_mlp,
        meshhead.pose_deform_mlp,
    )
    param_groups = [{'params': meshhead.landmarks_3d_neutral, 'lr': cfg.lr_lmk}]
    param_groups += [{'params': module.parameters(), 'lr': cfg.lr_net}
                     for module in net_modules]
    optimizer = torch.optim.Adam(param_groups)

    trainer = MeshHeadTrainer(dataloader, meshhead, camera, optimizer, recorder, cfg.gpu_id)
    # NOTE(review): (0, 50) is kept from the original; presumably the
    # start and end epoch. TODO confirm against MeshHeadTrainer.train.
    trainer.train(0, 50)


if __name__ == '__main__':
    main()