Spaces:
Sleeping
Sleeping
import argparse
import os

import torch

from config.config import config_reenactment
from lib.apps.Reenactment import Reenactment
from lib.dataset.DataLoaderX import DataLoaderX
from lib.dataset.Dataset import ReenactmentDataset
from lib.module.CameraModule import CameraModule
from lib.module.GaussianHeadModule import GaussianHeadModule
from lib.module.SuperResolutionModule import SuperResolutionModule
from lib.recorder.Recorder import ReenactmentRecorder
def parse_args(argv=None):
    """Parse the command-line options of the reenactment script.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse
            fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``config``,
        ``offline_rendering_param_fpath`` and ``stop_fid``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/reenactment_N031.yaml')
    parser.add_argument('--offline_rendering_param_fpath', type=str, default=None)
    # Previously hard-coded as 800 in the call to ``app.run``; the default
    # keeps existing invocations unchanged.
    parser.add_argument('--stop_fid', type=int, default=800,
                        help='value forwarded to Reenactment.run(stop_fid=...)')
    return parser.parse_args(argv)


def main(argv=None):
    """Build the reenactment pipeline from a config file and run it.

    Loads the configuration, creates the dataset/dataloader, restores the
    Gaussian head and super-resolution modules from their checkpoints, and
    hands everything to ``Reenactment``. When
    ``--offline_rendering_param_fpath`` is given, the app runs
    ``run_for_offline_stitching`` with that path; otherwise it calls
    ``run`` with the configured ``stop_fid``.

    Args:
        argv: Optional list of argument strings, forwarded to
            :func:`parse_args`.
    """
    args = parse_args(argv)

    cfg = config_reenactment()
    cfg.load(args.config)
    cfg = cfg.get_cfg()

    dataset = ReenactmentDataset(cfg.dataset)
    dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)

    device = torch.device('cuda:%d' % cfg.gpu_id)

    # Checkpoints are deserialized onto the CPU first; the modules themselves
    # are moved to ``device`` before the weights are copied in.
    gaussianhead_state_dict = torch.load(cfg.load_gaussianhead_checkpoint,
                                         map_location='cpu')
    # The module constructor takes these three tensors from the checkpoint
    # before the full state dict is loaded.
    gaussianhead = GaussianHeadModule(
        cfg.gaussianheadmodule,
        xyz=gaussianhead_state_dict['xyz'],
        feature=gaussianhead_state_dict['feature'],
        landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral'],
    ).to(device)
    gaussianhead.load_state_dict(gaussianhead_state_dict)

    supres_state_dict = torch.load(cfg.load_supres_checkpoint, map_location='cpu')
    supres = SuperResolutionModule(cfg.supresmodule).to(device)
    supres.load_state_dict(supres_state_dict)

    camera = CameraModule()
    recorder = ReenactmentRecorder(cfg.recorder)
    app = Reenactment(dataloader, gaussianhead, supres, camera, recorder,
                      cfg.gpu_id, dataset.freeview)

    if args.offline_rendering_param_fpath is None:
        app.run(stop_fid=args.stop_fid)
    else:
        app.run_for_offline_stitching(args.offline_rendering_param_fpath)


if __name__ == '__main__':
    main()