SWivid committed
Commit 520692d
1 Parent(s): b53eca8

Update app.py


Mainly redirect to split ckpt repos, minor updates.
fix "gen_text" -> "chunk"
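
The "split ckpt repos" redirect means checkpoints are now fetched from per-model repos (SWivid/F5-TTS and SWivid/E2-TTS), so load_model takes the repo name alongside the experiment name. A minimal sketch of the new path resolution, built only from names that appear in the diff below:

    from cached_path import cached_path

    # checkpoints now live in per-model repos rather than one shared repo
    repo_name, exp_name, ckpt_step = "F5-TTS", "F5TTS_Base", 1200000
    ckpt_path = cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")
    # cached_path downloads the checkpoint once and returns the local cached copy,
    # which load_model then hands to torch.load(..., map_location=device)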

Files changed (1)
  1. app.py +35 -18
app.py CHANGED
@@ -8,7 +8,7 @@ import tempfile
 from einops import rearrange
 from ema_pytorch import EMA
 from vocos import Vocos
-from pydub import AudioSegment
+from pydub import AudioSegment, silence
 from model import CFM, UNetT, DiT, MMDiT
 from cached_path import cached_path
 from model.utils import (
@@ -19,6 +19,7 @@ from model.utils import (
 from transformers import pipeline
 import spaces
 import librosa
+import soundfile as sf
 from txtsplit import txtsplit
 from detoxify import Detoxify
 
@@ -49,8 +50,8 @@ speed = 1.0
 # fix_duration = 27  # None or float (duration in seconds)
 fix_duration = None
 
-def load_model(exp_name, model_cls, model_cfg, ckpt_step):
-    checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
+def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
+    checkpoint = torch.load(str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
     vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
     model = CFM(
         transformer=model_cls(
@@ -73,14 +74,14 @@ def load_model(exp_name, model_cls, model_cfg, ckpt_step):
     ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
     ema_model.copy_params_from_ema_to_model()
 
-    return ema_model, model
+    return model
 
 # load models
 F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
-E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+F5TTS_ema_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+E2TTS_ema_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
 
 @spaces.GPU
 def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
@@ -91,6 +92,12 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
     gr.Info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
+        # remove long silence in reference audio
+        non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+        non_silent_wave = AudioSegment.silent(duration=0)
+        for non_silent_seg in non_silent_segs:
+            non_silent_wave += non_silent_seg
+        aseg = non_silent_wave
         # Convert to mono
         aseg = aseg.set_channels(1)
         audio_duration = len(aseg)
@@ -101,10 +108,8 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
         ref_audio = f.name
     if exp_name == "F5-TTS":
         ema_model = F5TTS_ema_model
-        base_model = F5TTS_base_model
     elif exp_name == "E2-TTS":
         ema_model = E2TTS_ema_model
-        base_model = E2TTS_base_model
 
     if not ref_text.strip():
         gr.Info("No reference text provided, transcribing reference audio...")
@@ -119,6 +124,7 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
     else:
         gr.Info("Using custom reference text...")
     audio, sr = torchaudio.load(ref_audio)
+    max_chars = int(len(ref_text) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
     # Audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
@@ -130,7 +136,7 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
         audio = resampler(audio)
     audio = audio.to(device)
     # Chunk
-    chunks = txtsplit(gen_text, 100, 150)  # 100 chars preferred, 150 max
+    chunks = txtsplit(gen_text, 0.7*max_chars, 0.9*max_chars)
     results = []
     generated_mel_specs = []
     for chunk in progress.tqdm(chunks):
@@ -144,14 +150,14 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
         #     duration = int(fix_duration * target_sample_rate / hop_length)
         # else:
         zh_pause_punc = r"。,、;:?!"
-        ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
-        gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
+        ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
+        gen_text_len = len(chunk.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, chunk))
         duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
 
         # inference
         gr.Info(f"Generating audio using {exp_name}")
         with torch.inference_mode():
-            generated, _ = base_model.sample(
+            generated, _ = ema_model.sample(
                 cond=audio,
                 text=final_text_list,
                 duration=duration,
@@ -174,12 +180,23 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress
     generated_wave = np.concatenate(results)
     if remove_silence:
         gr.Info("Removing audio silences... This may take a moment")
-        non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
-        non_silent_wave = np.array([])
-        for interval in non_silent_intervals:
-            start, end = interval
-            non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
-        generated_wave = non_silent_wave
+        # non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
+        # non_silent_wave = np.array([])
+        # for interval in non_silent_intervals:
+        #     start, end = interval
+        #     non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
+        # generated_wave = non_silent_wave
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+            sf.write(f.name, generated_wave, target_sample_rate)
+            aseg = AudioSegment.from_file(f.name)
+            non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
+            non_silent_wave = AudioSegment.silent(duration=0)
+            for non_silent_seg in non_silent_segs:
+                non_silent_wave += non_silent_seg
+            aseg = non_silent_wave
+            aseg.export(f.name, format="wav")
+            generated_wave, _ = torchaudio.load(f.name)
+        generated_wave = generated_wave.squeeze().cpu().numpy()
 
 
     # spectrogram
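
Note on the new chunking budget: max_chars estimates how many characters fit in what remains of a roughly 30-second generation window, scaling the reference transcript's characters-per-second by the remaining seconds; txtsplit then targets 0.7 * max_chars per chunk with a hard cap of 0.9 * max_chars. A rough worked example of the arithmetic (the clip and transcript lengths here are hypothetical, not from the commit):

    # hypothetical 6 s reference clip with an 80-character transcript
    ref_secs = 6.0
    ref_chars = 80
    max_chars = int(ref_chars / ref_secs * (30 - ref_secs))  # ~13.3 chars/s * 24 s -> 320
    # txtsplit(gen_text, 0.7 * max_chars, 0.9 * max_chars) -> chunks of ~224 chars, capped at 288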