# full_gaussian_avatar / GHA / train_gaussianhead.py
# Hosting-page metadata: pengc02 · "all" · ec9a6bc · 5 kB
import os
import torch
import argparse
from config.config import config_train
from lib.dataset.Dataset import GaussianDataset
from lib.dataset.DataLoaderX import DataLoaderX
from lib.module.MeshHeadModule import MeshHeadModule
from lib.module.GaussianHeadModule import GaussianHeadModule
from lib.module.SuperResolutionModule import SuperResolutionModule
from lib.module.CameraModule import CameraModule
from lib.recorder.Recorder import GaussianHeadTrainRecorder
from lib.trainer.GaussianHeadTrainer import GaussianHeadTrainer
if __name__ == '__main__':
    # Script entry point: build the dataset, the Gaussian head model, the
    # super-resolution network and the optimizer from a YAML config, then
    # hand everything to GaussianHeadTrainer.

    # ---- configuration -------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/train_s2_N031.yaml')
    arg = parser.parse_args()

    cfg = config_train()
    cfg.load(arg.config)
    cfg = cfg.get_cfg()

    # ---- data ----------------------------------------------------------
    dataset = GaussianDataset(cfg.dataset)
    dataloader = DataLoaderX(dataset, batch_size=cfg.batch_size, shuffle=True,
                             pin_memory=True, drop_last=True, num_workers=8)

    device = torch.device('cuda:%d' % cfg.gpu_id)
    torch.cuda.set_device(cfg.gpu_id)

    # ---- Gaussian head -------------------------------------------------
    if os.path.exists(cfg.load_gaussianhead_checkpoint):
        # Resume from an existing Gaussian head checkpoint. The state dict is
        # loaded onto the CPU first (identity map_location); the module is
        # moved to the target device after construction.
        gaussianhead_state_dict = torch.load(cfg.load_gaussianhead_checkpoint,
                                             map_location=lambda storage, loc: storage)
        gaussianhead = GaussianHeadModule(
            cfg.gaussianheadmodule,
            xyz=gaussianhead_state_dict['xyz'],
            feature=gaussianhead_state_dict['feature'],
            landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
        gaussianhead.load_state_dict(gaussianhead_state_dict)
    else:
        # No Gaussian head checkpoint: initialise from the mesh head instead.
        meshhead_state_dict = torch.load(cfg.load_meshhead_checkpoint,
                                         map_location=lambda storage, loc: storage)
        meshhead = MeshHeadModule(cfg.meshheadmodule,
                                  meshhead_state_dict['landmarks_3d_neutral']).to(device)
        meshhead.load_state_dict(meshhead_state_dict)

        meshhead.subdivide()
        with torch.no_grad():
            data = meshhead.reconstruct_neutral()

        # NOTE(review): atanh is applied to the vertex features here,
        # presumably because the Gaussian head applies tanh to them
        # internally -- confirm against GaussianHeadModule.
        gaussianhead = GaussianHeadModule(
            cfg.gaussianheadmodule,
            xyz=data['verts'].cpu(),
            feature=torch.atanh(data['verts_feature'].cpu()),
            landmarks_3d_neutral=meshhead.landmarks_3d_neutral.detach().cpu(),
            add_mouth_points=True).to(device)

        # Copy the colour/deformation MLP weights over from the mesh head.
        gaussianhead.exp_color_mlp.load_state_dict(meshhead.exp_color_mlp.state_dict())
        gaussianhead.pose_color_mlp.load_state_dict(meshhead.pose_color_mlp.state_dict())
        gaussianhead.exp_deform_mlp.load_state_dict(meshhead.exp_deform_mlp.state_dict())
        gaussianhead.pose_deform_mlp.load_state_dict(meshhead.pose_deform_mlp.state_dict())

    # ---- super-resolution network ---------------------------------------
    supres = SuperResolutionModule(cfg.supresmodule).to(device)
    if os.path.exists(cfg.load_supres_checkpoint):
        supres.load_state_dict(torch.load(cfg.load_supres_checkpoint,
                                          map_location=lambda storage, loc: storage))

    camera = CameraModule()
    recorder = GaussianHeadTrainRecorder(cfg.recorder)

    # ---- optimizer parameter groups --------------------------------------
    # Networks train at cfg.lr_net; the per-point tensors xyz, feature and
    # rotation use 0.1x, scales 0.3x, and opacity the full rate.
    optimized_parameters = [
        {'params': supres.parameters(), 'lr': cfg.lr_net},
        {'params': gaussianhead.xyz, 'lr': cfg.lr_net * 0.1},
        {'params': gaussianhead.feature, 'lr': cfg.lr_net * 0.1},
        {'params': gaussianhead.exp_color_mlp.parameters(), 'lr': cfg.lr_net},
        {'params': gaussianhead.pose_color_mlp.parameters(), 'lr': cfg.lr_net},
        {'params': gaussianhead.exp_deform_mlp.parameters(), 'lr': cfg.lr_net},
        {'params': gaussianhead.pose_deform_mlp.parameters(), 'lr': cfg.lr_net},
        {'params': gaussianhead.exp_attributes_mlp.parameters(), 'lr': cfg.lr_net},
        {'params': gaussianhead.pose_attributes_mlp.parameters(), 'lr': cfg.lr_net},
        {'params': gaussianhead.scales, 'lr': cfg.lr_net * 0.3},
        {'params': gaussianhead.rotation, 'lr': cfg.lr_net * 0.1},
        {'params': gaussianhead.opacity, 'lr': cfg.lr_net},
    ]

    # ---- per-expression pose offsets --------------------------------------
    if os.path.exists(cfg.load_delta_poses_checkpoint):
        # Remap onto the training device explicitly: without map_location the
        # tensor is restored on whatever device it was saved from, which can
        # differ from cfg.gpu_id. Remapping at load time (rather than calling
        # .to(device) afterwards) keeps the tensor a leaf, so it can still be
        # handed to the optimizer.
        delta_poses = torch.load(cfg.load_delta_poses_checkpoint, map_location=device)
    else:
        # One 6-component offset per expression id, starting at zero.
        delta_poses = torch.zeros([dataset.num_exp_id, 6], device=device)

    if cfg.optimize_pose:
        delta_poses = delta_poses.requires_grad_(True)
        optimized_parameters.append({'params': delta_poses, 'lr': cfg.lr_pose})
    else:
        delta_poses = delta_poses.requires_grad_(False)

    optimizer = torch.optim.Adam(optimized_parameters)

    # ---- training ---------------------------------------------------------
    trainer = GaussianHeadTrainer(dataloader, delta_poses, gaussianhead, supres,
                                  camera, optimizer, recorder, cfg.gpu_id)
    # NOTE(review): arguments are presumably (start_epoch, end_epoch) --
    # confirm against GaussianHeadTrainer.train.
    trainer.train(0, 1000)