mrfakename commited on
Commit
61075cd
·
verified ·
1 Parent(s): 5659999

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. src/f5_tts/model/trainer.py +5 -3
src/f5_tts/model/trainer.py CHANGED
@@ -167,14 +167,16 @@ class Trainer:
167
  for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
168
  if key in checkpoint["ema_model_state_dict"]:
169
  del checkpoint["ema_model_state_dict"][key]
170
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
171
- if key in checkpoint["model_state_dict"]:
172
- del checkpoint["model_state_dict"][key]
173
 
174
  if self.is_main:
175
  self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
176
 
177
  if "step" in checkpoint:
 
 
 
 
 
178
  self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
179
  self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
180
  if self.scheduler:
 
167
  for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
168
  if key in checkpoint["ema_model_state_dict"]:
169
  del checkpoint["ema_model_state_dict"][key]
 
 
 
170
 
171
  if self.is_main:
172
  self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
173
 
174
  if "step" in checkpoint:
175
+ # patch for backward compatibility, 305e3ea
176
+ for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
177
+ if key in checkpoint["model_state_dict"]:
178
+ del checkpoint["model_state_dict"][key]
179
+
180
  self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
181
  self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
182
  if self.scheduler: