full_gaussian_avatar / GHA /reenactment.py
pengc02's picture
all
ec9a6bc
raw
history blame
2.09 kB
import os
import torch
import argparse
from config.config import config_reenactment
from lib.dataset.Dataset import ReenactmentDataset
from lib.dataset.DataLoaderX import DataLoaderX
from lib.module.GaussianHeadModule import GaussianHeadModule
from lib.module.SuperResolutionModule import SuperResolutionModule
from lib.module.CameraModule import CameraModule
from lib.recorder.Recorder import ReenactmentRecorder
from lib.apps.Reenactment import Reenactment
if __name__ == '__main__':
    # Script entry point: load a reenactment config, restore the Gaussian head
    # and super-resolution modules from their checkpoints, and hand everything
    # to the Reenactment app.
    #
    # Command-line arguments:
    #   --config                        path of the YAML config to load.
    #   --offline_rendering_param_fpath when given, the app runs its
    #                                   offline-stitching path with this file
    #                                   instead of the regular run.
    #   --stop_fid                      value forwarded to Reenactment.run as
    #                                   ``stop_fid`` (default 800, the value
    #                                   that used to be hard-coded here).
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/reenactment_N031.yaml')
    parser.add_argument('--offline_rendering_param_fpath', type=str, default=None)
    parser.add_argument('--stop_fid', type=int, default=800,
                        help='value passed as stop_fid to Reenactment.run (default: 800)')
    arg = parser.parse_args()

    # Build the config object, populate it from the YAML file, then swap the
    # wrapper for the plain config node returned by get_cfg().
    cfg = config_reenactment()
    cfg.load(arg.config)
    cfg = cfg.get_cfg()

    # One sample per batch, in dataset order.
    dataset = ReenactmentDataset(cfg.dataset)
    dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)

    device = torch.device('cuda:%d' % cfg.gpu_id)

    # NOTE(review): torch.load unpickles the checkpoint file, so only point the
    # config at checkpoints from a trusted source.
    # The identity map_location leaves storages where they were deserialized
    # instead of restoring them to the device they were saved from; the modules
    # are moved to `device` explicitly via .to(device).
    gaussianhead_state_dict = torch.load(cfg.load_gaussianhead_checkpoint,
                                         map_location=lambda storage, loc: storage)
    # The module constructor takes these three tensors directly from the
    # checkpoint; the full state dict is loaded right after.
    gaussianhead = GaussianHeadModule(cfg.gaussianheadmodule,
                                      xyz=gaussianhead_state_dict['xyz'],
                                      feature=gaussianhead_state_dict['feature'],
                                      landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
    gaussianhead.load_state_dict(gaussianhead_state_dict)

    supres = SuperResolutionModule(cfg.supresmodule).to(device)
    supres.load_state_dict(torch.load(cfg.load_supres_checkpoint,
                                      map_location=lambda storage, loc: storage))

    camera = CameraModule()
    recorder = ReenactmentRecorder(cfg.recorder)

    app = Reenactment(dataloader, gaussianhead, supres, camera, recorder, cfg.gpu_id, dataset.freeview)

    if arg.offline_rendering_param_fpath is None:
        # Regular run; stop_fid is presumably the frame id at which the app
        # stops -- confirm against Reenactment.run.
        app.run(stop_fid=arg.stop_fid)
    else:
        app.run_for_offline_stitching(arg.offline_rendering_param_fpath)