File size: 2,087 Bytes
ec9a6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import torch
import argparse

from config.config import config_reenactment

from lib.dataset.Dataset import ReenactmentDataset
from lib.dataset.DataLoaderX import DataLoaderX
from lib.module.GaussianHeadModule import GaussianHeadModule
from lib.module.SuperResolutionModule import SuperResolutionModule
from lib.module.CameraModule import CameraModule
from lib.recorder.Recorder import ReenactmentRecorder
from lib.apps.Reenactment import Reenactment

if __name__ == '__main__':
    # Command-line interface:
    #   --config                        path to the reenactment YAML config.
    #   --offline_rendering_param_fpath when given, run the offline-stitching
    #                                   path instead of the regular run.
    #   --stop_fid                      value forwarded to app.run(stop_fid=...);
    #                                   defaults to the previously hard-coded 800.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/reenactment_N031.yaml')
    parser.add_argument('--offline_rendering_param_fpath', type=str, default=None)
    parser.add_argument('--stop_fid', type=int, default=800)
    args = parser.parse_args()

    # Build the config object, overlay the chosen file, then fetch the final cfg.
    cfg = config_reenactment()
    cfg.load(args.config)
    cfg = cfg.get_cfg()

    # Batches of one sample, in dataset order (no shuffling).
    dataset = ReenactmentDataset(cfg.dataset)
    dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)

    device = torch.device('cuda:%d' % cfg.gpu_id)

    # The Gaussian head checkpoint is used twice: three of its entries are
    # passed to the module constructor, and the whole dict is then loaded as
    # the module's state dict. Checkpoints are deserialized onto the CPU
    # ('cpu' is the documented equivalent of `lambda storage, loc: storage`).
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    gaussianhead_state_dict = torch.load(cfg.load_gaussianhead_checkpoint, map_location='cpu')
    gaussianhead = GaussianHeadModule(
        cfg.gaussianheadmodule,
        xyz=gaussianhead_state_dict['xyz'],
        feature=gaussianhead_state_dict['feature'],
        landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral'],
    ).to(device)
    gaussianhead.load_state_dict(gaussianhead_state_dict)

    # Super-resolution module, restored from its own checkpoint.
    supres = SuperResolutionModule(cfg.supresmodule).to(device)
    supres_state_dict = torch.load(cfg.load_supres_checkpoint, map_location='cpu')
    supres.load_state_dict(supres_state_dict)

    camera = CameraModule()
    recorder = ReenactmentRecorder(cfg.recorder)

    app = Reenactment(dataloader, gaussianhead, supres, camera, recorder, cfg.gpu_id, dataset.freeview)
    if args.offline_rendering_param_fpath is not None:
        # An offline rendering parameter file was supplied: use the stitching path.
        app.run_for_offline_stitching(args.offline_rendering_param_fpath)
    else:
        app.run(stop_fid=args.stop_fid)