Spanicin commited on
Commit
4b96359
1 Parent(s): 784616c

Update videoretalking/models/__init__.py

Browse files
Files changed (1) hide show
  1. videoretalking/models/__init__.py +36 -36
videoretalking/models/__init__.py CHANGED
@@ -1,37 +1,37 @@
1
- import torch
2
- from models.DNet import DNet
3
- from models.LNet import LNet
4
- from models.ENet import ENet
5
-
6
-
7
- def _load(checkpoint_path):
8
- map_location=None if torch.cuda.is_available() else torch.device('cpu')
9
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
10
- return checkpoint
11
-
12
- def load_checkpoint(path, model):
13
- print("Load checkpoint from: {}".format(path))
14
- checkpoint = _load(path)
15
- s = checkpoint["state_dict"] if 'arcface' not in path else checkpoint
16
- new_s = {}
17
- for k, v in s.items():
18
- if 'low_res' in k:
19
- continue
20
- else:
21
- new_s[k.replace('module.', '')] = v
22
- model.load_state_dict(new_s, strict=False)
23
- return model
24
-
25
- def load_network(LNet_path,ENet_path):
26
- L_net = LNet()
27
- L_net = load_checkpoint(LNet_path, L_net)
28
- E_net = ENet(lnet=L_net)
29
- model = load_checkpoint(ENet_path, E_net)
30
- return model.eval()
31
-
32
- def load_DNet(DNet_path):
33
- D_Net = DNet()
34
- print("Load checkpoint from: {}".format(DNet_path))
35
- checkpoint = torch.load(DNet_path, map_location=lambda storage, loc: storage)
36
- D_Net.load_state_dict(checkpoint['net_G_ema'], strict=False)
37
  return D_Net.eval()
 
1
+ import torch
2
+ from videoretalking.models.DNet import DNet
3
+ from videoretalking.models.LNet import LNet
4
+ from videoretalking.models.ENet import ENet
5
+
6
+
7
+ def _load(checkpoint_path):
8
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
9
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
10
+ return checkpoint
11
+
12
+ def load_checkpoint(path, model):
13
+ print("Load checkpoint from: {}".format(path))
14
+ checkpoint = _load(path)
15
+ s = checkpoint["state_dict"] if 'arcface' not in path else checkpoint
16
+ new_s = {}
17
+ for k, v in s.items():
18
+ if 'low_res' in k:
19
+ continue
20
+ else:
21
+ new_s[k.replace('module.', '')] = v
22
+ model.load_state_dict(new_s, strict=False)
23
+ return model
24
+
25
+ def load_network(LNet_path,ENet_path):
26
+ L_net = LNet()
27
+ L_net = load_checkpoint(LNet_path, L_net)
28
+ E_net = ENet(lnet=L_net)
29
+ model = load_checkpoint(ENet_path, E_net)
30
+ return model.eval()
31
+
32
+ def load_DNet(DNet_path):
33
+ D_Net = DNet()
34
+ print("Load checkpoint from: {}".format(DNet_path))
35
+ checkpoint = torch.load(DNet_path, map_location=lambda storage, loc: storage)
36
+ D_Net.load_state_dict(checkpoint['net_G_ema'], strict=False)
37
  return D_Net.eval()