import torch

from videoretalking.models.DNet import DNet
from videoretalking.models.LNet import LNet
from videoretalking.models.ENet import ENet


def _load(checkpoint_path):
    """Deserialize the checkpoint stored at *checkpoint_path* via ``torch.load``.

    If CUDA is not available, tensors are remapped onto the CPU device.
    Otherwise ``map_location`` is passed as ``None`` and ``torch.load`` applies
    its default placement.

    Returns:
        Whatever object ``torch.load`` produces for the file.
    """
    if torch.cuda.is_available():
        device_map = None
    else:
        device_map = torch.device('cpu')
    return torch.load(checkpoint_path, map_location=device_map)


def load_checkpoint(path, model):
    """Load the weights found at *path* into *model* and return *model*.

    Args:
        path: Checkpoint location. It is also inspected as text: when it
            contains ``'arcface'`` the loaded object itself is used as the
            weight mapping, otherwise its ``"state_dict"`` entry is used.
        model: Object exposing ``load_state_dict``; it is updated in place.

    Returns:
        The same *model* object that was passed in.
    """
    print("Load checkpoint from: {}".format(path))
    loaded = _load(path)

    if 'arcface' in path:
        weights = loaded
    else:
        weights = loaded["state_dict"]

    # Skip every entry whose key mentions 'low_res', and remove each
    # occurrence of 'module.' from the remaining keys.
    # NOTE(review): 'module.' is presumably the prefix added by a
    # DataParallel-style wrapper at save time -- confirm against the
    # training code.
    cleaned = {
        name.replace('module.', ''): tensor
        for name, tensor in weights.items()
        if 'low_res' not in name
    }

    # strict=False: key mismatches between `cleaned` and the model are
    # tolerated rather than raised.
    model.load_state_dict(cleaned, strict=False)
    return model


def load_network(LNet_path, ENet_path):
    """Build an ``ENet`` around a checkpoint-initialised ``LNet``.

    The ``LNet`` is created and populated from *LNet_path* first, then handed
    to the ``ENet`` constructor as ``lnet``; the resulting ``ENet`` is in turn
    populated from *ENet_path*.

    Returns:
        The result of calling ``.eval()`` on the loaded ``ENet``.
    """
    lnet = load_checkpoint(LNet_path, LNet())
    enet = load_checkpoint(ENet_path, ENet(lnet=lnet))
    return enet.eval()


def load_DNet(DNet_path):
    """Create a ``DNet`` and populate it from the checkpoint at *DNet_path*.

    The weights are read from the checkpoint's ``'net_G_ema'`` entry and
    applied with ``strict=False``.

    Returns:
        The result of calling ``.eval()`` on the loaded ``DNet``.
    """

    def _keep_storage(storage, loc):
        # Hand back the storage exactly as received, so no remapping to a
        # different device is requested.
        return storage

    generator = DNet()
    print("Load checkpoint from: {}".format(DNet_path))
    state = torch.load(DNet_path, map_location=_keep_storage)
    generator.load_state_dict(state['net_G_ema'], strict=False)
    return generator.eval()