File size: 1,227 Bytes
4b96359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b1ae50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from videoretalking.models.DNet import DNet
from videoretalking.models.LNet import LNet
from videoretalking.models.ENet import ENet


def _load(checkpoint_path):
    map_location=None if torch.cuda.is_available() else torch.device('cpu')
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    return checkpoint

def load_checkpoint(path, model):
    """Load weights from a checkpoint into ``model`` and return the same model.

    Args:
        path: Checkpoint location forwarded to ``_load``. If its string form
            contains 'arcface', the loaded object itself is used as the state
            dict; otherwise the weights are read from its "state_dict" entry.
        model: Module whose ``load_state_dict`` receives the cleaned weights.

    Returns:
        ``model``, updated in place.
    """
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    # str() so pathlib.Path / other path-like arguments work with the substring test.
    state = checkpoint if 'arcface' in str(path) else checkpoint["state_dict"]
    cleaned = {}
    for key, value in state.items():
        # Entries whose key mentions 'low_res' are deliberately skipped.
        if 'low_res' in key:
            continue
        # Drop wrapper components named exactly 'module' (e.g. a DataParallel
        # prefix, possibly nested). A plain str.replace('module.', '') would also
        # corrupt names like 'submodule.weight' -> 'subweight', and with
        # strict=False such corrupted keys would be silently ignored.
        parts = key.split('.')
        key = '.'.join([p for p in parts[:-1] if p != 'module'] + parts[-1:])
        cleaned[key] = value
    # strict=False: missing/unexpected keys are tolerated instead of raising.
    model.load_state_dict(cleaned, strict=False)
    return model

def load_network(LNet_path, ENet_path):
    """Build the LNet/ENet stack from two checkpoints and return it in eval mode.

    An ``LNet`` is created and populated from ``LNet_path``. It is then handed to
    a new ``ENet`` (as ``lnet``), which is populated from ``ENet_path``.

    Returns:
        The ``ENet`` instance after ``.eval()``.
    """
    lnet = load_checkpoint(LNet_path, LNet())
    enet = load_checkpoint(ENet_path, ENet(lnet=lnet))
    return enet.eval()

def load_DNet(DNet_path):
    """Create a ``DNet`` and fill it with the 'net_G_ema' weights from a checkpoint.

    Loading uses ``strict=False``, so missing or unexpected keys do not raise.

    Returns:
        The ``DNet`` instance after ``.eval()``.
    """
    net = DNet()
    print("Load checkpoint from: {}".format(DNet_path))

    def _keep_storage(storage, loc):
        # torch.load hands the callable the CPU-deserialized storage; returning
        # it unchanged keeps every tensor on the CPU regardless of `loc`.
        return storage

    checkpoint = torch.load(DNet_path, map_location=_keep_storage)
    net.load_state_dict(checkpoint['net_G_ema'], strict=False)
    return net.eval()