mrfakename commited on
Commit
5659999
·
verified ·
1 Parent(s): 5a9adbc

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

src/f5_tts/infer/utils_infer.py CHANGED
@@ -156,6 +156,7 @@ def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
156
  if k not in ["initted", "step"]
157
  }
158
 
 
159
  for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
160
  if key in checkpoint["model_state_dict"]:
161
  del checkpoint["model_state_dict"][key]
 
156
  if k not in ["initted", "step"]
157
  }
158
 
159
+ # patch for backward compatibility, 305e3ea
160
  for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
161
  if key in checkpoint["model_state_dict"]:
162
  del checkpoint["model_state_dict"][key]
src/f5_tts/model/trainer.py CHANGED
@@ -163,6 +163,14 @@ class Trainer:
163
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
164
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
165
 
 
 
 
 
 
 
 
 
166
  if self.is_main:
167
  self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
168
 
 
163
  # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
164
  checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
165
 
166
+ # patch for backward compatibility, 305e3ea
167
+ for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
168
+ if key in checkpoint["ema_model_state_dict"]:
169
+ del checkpoint["ema_model_state_dict"][key]
170
+ for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
171
+ if key in checkpoint["model_state_dict"]:
172
+ del checkpoint["model_state_dict"][key]
173
+
174
  if self.is_main:
175
  self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
176