Spaces:
Paused
Paused
File size: 1,227 Bytes
4b96359 5b1ae50 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
from videoretalking.models.DNet import DNet
from videoretalking.models.LNet import LNet
from videoretalking.models.ENet import ENet
def _load(checkpoint_path):
map_location=None if torch.cuda.is_available() else torch.device('cpu')
checkpoint = torch.load(checkpoint_path, map_location=map_location)
return checkpoint
def load_checkpoint(path, model):
    """Load the weights stored at ``path`` into ``model`` and return ``model``.

    The checkpoint's weights are read from its ``"state_dict"`` entry, except
    when ``path`` contains ``'arcface'``; then the loaded object itself is
    treated as the state dict.

    Before loading:
    - Entries whose key contains ``'low_res'`` are skipped.
    - Any leading ``'module.'`` prefix (the kind added by DataParallel-style
      wrappers) is removed from each key.

    Loading uses ``strict=False``, so keys that do not match ``model`` are
    ignored without error.

    Args:
        path: Checkpoint file path (``str`` or ``os.PathLike``).
        model: Module whose ``load_state_dict`` receives the cleaned weights;
            it is modified in place.

    Returns:
        The same ``model`` object that was passed in.
    """
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    # str() so the substring test also works for pathlib.Path arguments.
    state = checkpoint if 'arcface' in str(path) else checkpoint["state_dict"]

    prefix = 'module.'
    cleaned = {}
    for key, value in state.items():
        # NOTE(review): low_res entries are presumably owned by a sub-network
        # that is loaded separately (see load_network); confirm.
        if 'low_res' in key:
            continue
        # Strip only *leading* wrapper prefixes (repeatedly, for nested
        # wrappers). str.replace would also mangle keys that merely contain
        # 'module.', e.g. 'encoder.submodule.w' -> 'encoder.subw', and with
        # strict=False such a weight would then be dropped silently.
        while key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value

    model.load_state_dict(cleaned, strict=False)
    return model
def load_network(LNet_path, ENet_path):
    """Assemble and return the inference network in eval mode.

    Steps:
    1. Construct an ``LNet`` and load the weights from ``LNet_path`` into it
       (via ``load_checkpoint``).
    2. Pass that LNet as ``lnet=`` to a new ``ENet``.
    3. Load the weights from ``ENet_path`` into the ENet.

    Returns the ENet after ``.eval()``.
    """
    lnet = load_checkpoint(LNet_path, LNet())
    enet = load_checkpoint(ENet_path, ENet(lnet=lnet))
    return enet.eval()
def load_DNet(DNet_path):
    """Construct a ``DNet``, load its weights from ``DNet_path``, return it in eval mode.

    The weights are read from the checkpoint's ``'net_G_ema'`` entry. Every
    stored tensor is kept on the CPU, even when CUDA is available. Loading
    uses ``strict=False``, so mismatched keys are ignored without error.

    NOTE: ``torch.load`` is pickle-based; only load checkpoints from trusted
    sources.
    """
    net = DNet()
    print("Load checkpoint from: {}".format(DNet_path))

    def keep_on_cpu(storage, location):
        # Returning the deserialized storage unchanged keeps it on the CPU.
        return storage

    weights = torch.load(DNet_path, map_location=keep_on_cpu)['net_G_ema']
    net.load_state_dict(weights, strict=False)
    return net.eval()