"""Train a GaussianHeadModule jointly with a SuperResolutionModule.

The Gaussian head is restored from its own checkpoint when one exists;
otherwise it is initialised from a MeshHeadModule checkpoint (subdivided
neutral mesh + copied expression/pose MLP weights).
"""
import os
import argparse

import torch

from config.config import config_train
from lib.dataset.Dataset import GaussianDataset
from lib.dataset.DataLoaderX import DataLoaderX
from lib.module.MeshHeadModule import MeshHeadModule
from lib.module.GaussianHeadModule import GaussianHeadModule
from lib.module.SuperResolutionModule import SuperResolutionModule
from lib.module.CameraModule import CameraModule
from lib.recorder.Recorder import GaussianHeadTrainRecorder
from lib.trainer.GaussianHeadTrainer import GaussianHeadTrainer


def _cpu_map_location(storage, loc):
    """``torch.load`` map_location callback that keeps every storage on the CPU."""
    return storage


def _checkpoint_exists(path):
    """Return True only for a non-empty *path* that exists on disk.

    ``os.path.exists(None)`` raises ``TypeError``, so an unset optional
    checkpoint entry in the config is treated as "no checkpoint" instead.
    """
    return bool(path) and os.path.exists(path)


def _build_gaussianhead(cfg, device):
    """Create the GaussianHeadModule on *device*.

    Uses ``cfg.load_gaussianhead_checkpoint`` when it exists; otherwise
    initialises from the MeshHeadModule in ``cfg.load_meshhead_checkpoint``
    (which must then exist, or ``torch.load`` raises).
    """
    if _checkpoint_exists(cfg.load_gaussianhead_checkpoint):
        state_dict = torch.load(cfg.load_gaussianhead_checkpoint, map_location=_cpu_map_location)
        gaussianhead = GaussianHeadModule(
            cfg.gaussianheadmodule,
            xyz=state_dict['xyz'],
            feature=state_dict['feature'],
            landmarks_3d_neutral=state_dict['landmarks_3d_neutral'],
        ).to(device)
        gaussianhead.load_state_dict(state_dict)
        return gaussianhead

    meshhead_state_dict = torch.load(cfg.load_meshhead_checkpoint, map_location=_cpu_map_location)
    meshhead = MeshHeadModule(cfg.meshheadmodule, meshhead_state_dict['landmarks_3d_neutral']).to(device)
    meshhead.load_state_dict(meshhead_state_dict)
    meshhead.subdivide()
    with torch.no_grad():
        data = meshhead.reconstruct_neutral()

    gaussianhead = GaussianHeadModule(
        cfg.gaussianheadmodule,
        xyz=data['verts'].cpu(),
        # atanh presumably inverts a tanh activation on verts_feature -- TODO confirm.
        # NOTE(review): feature values of exactly +/-1 would become +/-inf here.
        feature=torch.atanh(data['verts_feature'].cpu()),
        landmarks_3d_neutral=meshhead.landmarks_3d_neutral.detach().cpu(),
        add_mouth_points=True,
    ).to(device)

    # Warm-start the expression/pose MLPs from the mesh head's trained weights.
    gaussianhead.exp_color_mlp.load_state_dict(meshhead.exp_color_mlp.state_dict())
    gaussianhead.pose_color_mlp.load_state_dict(meshhead.pose_color_mlp.state_dict())
    gaussianhead.exp_deform_mlp.load_state_dict(meshhead.exp_deform_mlp.state_dict())
    gaussianhead.pose_deform_mlp.load_state_dict(meshhead.pose_deform_mlp.state_dict())
    return gaussianhead


def _build_supres(cfg, device):
    """Create the SuperResolutionModule on *device*, restoring weights if a checkpoint exists."""
    supres = SuperResolutionModule(cfg.supresmodule).to(device)
    if _checkpoint_exists(cfg.load_supres_checkpoint):
        supres.load_state_dict(torch.load(cfg.load_supres_checkpoint, map_location=_cpu_map_location))
    return supres


def _build_delta_poses(cfg, num_exp_id, device):
    """Return the pose-correction tensor on *device*.

    Loaded from ``cfg.load_delta_poses_checkpoint`` when it exists, otherwise
    zeros of shape ``[num_exp_id, 6]``. ``requires_grad`` follows
    ``cfg.optimize_pose``.
    """
    if _checkpoint_exists(cfg.load_delta_poses_checkpoint):
        # Load on CPU and detach so the result is a leaf tensor, then move it to the
        # training device -- the checkpoint may have been saved on a different device.
        delta_poses = torch.load(
            cfg.load_delta_poses_checkpoint, map_location=_cpu_map_location
        ).detach().to(device)
    else:
        delta_poses = torch.zeros([num_exp_id, 6]).to(device)
    return delta_poses.requires_grad_(bool(cfg.optimize_pose))


def _build_optimizer(cfg, gaussianhead, supres, delta_poses):
    """Create the Adam optimizer; parameter-group order is kept stable for checkpoint compatibility."""
    lr = cfg.lr_net
    param_groups = [
        {'params': supres.parameters(), 'lr': lr},
        {'params': gaussianhead.xyz, 'lr': lr * 0.1},
        {'params': gaussianhead.feature, 'lr': lr * 0.1},
        {'params': gaussianhead.exp_color_mlp.parameters(), 'lr': lr},
        {'params': gaussianhead.pose_color_mlp.parameters(), 'lr': lr},
        {'params': gaussianhead.exp_deform_mlp.parameters(), 'lr': lr},
        {'params': gaussianhead.pose_deform_mlp.parameters(), 'lr': lr},
        {'params': gaussianhead.exp_attributes_mlp.parameters(), 'lr': lr},
        {'params': gaussianhead.pose_attributes_mlp.parameters(), 'lr': lr},
        {'params': gaussianhead.scales, 'lr': lr * 0.3},
        {'params': gaussianhead.rotation, 'lr': lr * 0.1},
        {'params': gaussianhead.opacity, 'lr': lr},
    ]
    if cfg.optimize_pose:
        param_groups.append({'params': delta_poses, 'lr': cfg.lr_pose})
    return torch.optim.Adam(param_groups)


def main():
    """Parse ``--config``, build all modules, and run training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/train_s2_N031.yaml')
    arg = parser.parse_args()

    cfg = config_train()
    cfg.load(arg.config)
    cfg = cfg.get_cfg()

    dataset = GaussianDataset(cfg.dataset)
    dataloader = DataLoaderX(dataset, batch_size=cfg.batch_size, shuffle=True,
                             pin_memory=True, drop_last=True, num_workers=8)

    device = torch.device('cuda:%d' % cfg.gpu_id)
    torch.cuda.set_device(cfg.gpu_id)

    gaussianhead = _build_gaussianhead(cfg, device)
    supres = _build_supres(cfg, device)
    camera = CameraModule()
    recorder = GaussianHeadTrainRecorder(cfg.recorder)

    delta_poses = _build_delta_poses(cfg, dataset.num_exp_id, device)
    optimizer = _build_optimizer(cfg, gaussianhead, supres, delta_poses)

    trainer = GaussianHeadTrainer(dataloader, delta_poses, gaussianhead, supres, camera,
                                  optimizer, recorder, cfg.gpu_id)
    # NOTE(review): (0, 1000) is hard-coded -- presumably start epoch and epoch count; confirm
    # against GaussianHeadTrainer.train before making it configurable.
    trainer.train(0, 1000)


if __name__ == '__main__':
    main()