Spanicin commited on
Commit
fc094e6
1 Parent(s): 0c4a490

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +3 -3
src/facerender/animate.py CHANGED
@@ -27,7 +27,7 @@ from src.utils.videoio import save_video_with_watermark
27
  class AnimateFromCoeff():
28
 
29
  def __init__(self, free_view_checkpoint, mapping_checkpoint,
30
- config_path, device):
31
 
32
  with open(config_path) as f:
33
  config = yaml.safe_load(f)
@@ -88,7 +88,7 @@ class AnimateFromCoeff():
88
  def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
89
  kp_detector=None, he_estimator=None, optimizer_generator=None,
90
  optimizer_discriminator=None, optimizer_kp_detector=None,
91
- optimizer_he_estimator=None, device="cpu"):
92
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
93
 
94
  def adjust_state_dict(state_dict, model):
@@ -135,7 +135,7 @@ class AnimateFromCoeff():
135
  return checkpoint['epoch']
136
 
137
  def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
138
- optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
139
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
140
 
141
  def adjust_state_dict(state_dict, model):
 
27
  class AnimateFromCoeff():
28
 
29
  def __init__(self, free_view_checkpoint, mapping_checkpoint,
30
+ config_path, device='cuda'):
31
 
32
  with open(config_path) as f:
33
  config = yaml.safe_load(f)
 
88
  def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
89
  kp_detector=None, he_estimator=None, optimizer_generator=None,
90
  optimizer_discriminator=None, optimizer_kp_detector=None,
91
+ optimizer_he_estimator=None, device="cuda"):
92
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
93
 
94
  def adjust_state_dict(state_dict, model):
 
135
  return checkpoint['epoch']
136
 
137
  def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
138
+ optimizer_mapping=None, optimizer_discriminator=None, device='cuda'):
139
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
140
 
141
  def adjust_state_dict(state_dict, model):