mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- api.py +13 -11
- model/backbones/dit.py +1 -1
- model/backbones/unett.py +1 -1
- model/utils_infer.py +27 -23
api.py
CHANGED
@@ -69,6 +69,10 @@ class F5TTS:
|
|
69 |
ref_file,
|
70 |
ref_text,
|
71 |
gen_text,
|
|
|
|
|
|
|
|
|
72 |
sway_sampling_coef=-1,
|
73 |
cfg_strength=2,
|
74 |
nfe_step=32,
|
@@ -77,23 +81,21 @@ class F5TTS:
|
|
77 |
remove_silence=False,
|
78 |
file_wave=None,
|
79 |
file_spect=None,
|
80 |
-
cross_fade_duration=0.15,
|
81 |
-
show_info=print,
|
82 |
-
progress=tqdm,
|
83 |
):
|
84 |
wav, sr, spect = infer_process(
|
85 |
ref_file,
|
86 |
ref_text,
|
87 |
gen_text,
|
88 |
self.ema_model,
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
nfe_step,
|
94 |
-
cfg_strength,
|
95 |
-
sway_sampling_coef,
|
96 |
-
|
|
|
97 |
)
|
98 |
|
99 |
if file_wave is not None:
|
|
|
69 |
ref_file,
|
70 |
ref_text,
|
71 |
gen_text,
|
72 |
+
show_info=print,
|
73 |
+
progress=tqdm,
|
74 |
+
target_rms=0.1,
|
75 |
+
cross_fade_duration=0.15,
|
76 |
sway_sampling_coef=-1,
|
77 |
cfg_strength=2,
|
78 |
nfe_step=32,
|
|
|
81 |
remove_silence=False,
|
82 |
file_wave=None,
|
83 |
file_spect=None,
|
|
|
|
|
|
|
84 |
):
|
85 |
wav, sr, spect = infer_process(
|
86 |
ref_file,
|
87 |
ref_text,
|
88 |
gen_text,
|
89 |
self.ema_model,
|
90 |
+
show_info=show_info,
|
91 |
+
progress=progress,
|
92 |
+
target_rms=target_rms,
|
93 |
+
cross_fade_duration=cross_fade_duration,
|
94 |
+
nfe_step=nfe_step,
|
95 |
+
cfg_strength=cfg_strength,
|
96 |
+
sway_sampling_coef=sway_sampling_coef,
|
97 |
+
speed=speed,
|
98 |
+
fix_duration=fix_duration,
|
99 |
)
|
100 |
|
101 |
if file_wave is not None:
|
model/backbones/dit.py
CHANGED
@@ -45,9 +45,9 @@ class TextEmbedding(nn.Module):
|
|
45 |
self.extra_modeling = False
|
46 |
|
47 |
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
48 |
-
batch, text_len = text.shape[0], text.shape[1]
|
49 |
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
50 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
|
|
51 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
52 |
|
53 |
if drop_text: # cfg for text
|
|
|
45 |
self.extra_modeling = False
|
46 |
|
47 |
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
|
|
48 |
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
49 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
50 |
+
batch, text_len = text.shape[0], text.shape[1]
|
51 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
52 |
|
53 |
if drop_text: # cfg for text
|
model/backbones/unett.py
CHANGED
@@ -48,9 +48,9 @@ class TextEmbedding(nn.Module):
|
|
48 |
self.extra_modeling = False
|
49 |
|
50 |
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
51 |
-
batch, text_len = text.shape[0], text.shape[1]
|
52 |
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
53 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
|
|
54 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
55 |
|
56 |
if drop_text: # cfg for text
|
|
|
48 |
self.extra_modeling = False
|
49 |
|
50 |
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
|
|
51 |
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
52 |
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
53 |
+
batch, text_len = text.shape[0], text.shape[1]
|
54 |
text = F.pad(text, (0, seq_len - text_len), value=0)
|
55 |
|
56 |
if drop_text: # cfg for text
|
model/utils_infer.py
CHANGED
@@ -31,12 +31,13 @@ target_sample_rate = 24000
|
|
31 |
n_mel_channels = 100
|
32 |
hop_length = 256
|
33 |
target_rms = 0.1
|
34 |
-
|
35 |
-
|
36 |
-
#
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
40 |
|
41 |
# -----------------------------------------
|
42 |
|
@@ -107,7 +108,7 @@ def initialize_asr_pipeline(device=device):
|
|
107 |
# load model for inference
|
108 |
|
109 |
|
110 |
-
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=
|
111 |
if vocab_file == "":
|
112 |
vocab_file = "Emilia_ZH_EN"
|
113 |
tokenizer = "pinyin"
|
@@ -192,14 +193,15 @@ def infer_process(
|
|
192 |
ref_text,
|
193 |
gen_text,
|
194 |
model_obj,
|
195 |
-
cross_fade_duration=0.15,
|
196 |
-
speed=1.0,
|
197 |
show_info=print,
|
198 |
progress=tqdm,
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
203 |
):
|
204 |
# Split the input text into batches
|
205 |
audio, sr = torchaudio.load(ref_audio)
|
@@ -214,13 +216,14 @@ def infer_process(
|
|
214 |
ref_text,
|
215 |
gen_text_batches,
|
216 |
model_obj,
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
nfe_step,
|
221 |
-
cfg_strength,
|
222 |
-
sway_sampling_coef,
|
223 |
-
|
|
|
224 |
)
|
225 |
|
226 |
|
@@ -232,12 +235,13 @@ def infer_batch_process(
|
|
232 |
ref_text,
|
233 |
gen_text_batches,
|
234 |
model_obj,
|
235 |
-
cross_fade_duration=0.15,
|
236 |
-
speed=1,
|
237 |
progress=tqdm,
|
|
|
|
|
238 |
nfe_step=32,
|
239 |
cfg_strength=2.0,
|
240 |
sway_sampling_coef=-1,
|
|
|
241 |
fix_duration=None,
|
242 |
):
|
243 |
audio, sr = ref_audio
|
@@ -262,11 +266,11 @@ def infer_batch_process(
|
|
262 |
text_list = [ref_text + gen_text]
|
263 |
final_text_list = convert_char_to_pinyin(text_list)
|
264 |
|
|
|
265 |
if fix_duration is not None:
|
266 |
duration = int(fix_duration * target_sample_rate / hop_length)
|
267 |
else:
|
268 |
# Calculate duration
|
269 |
-
ref_audio_len = audio.shape[-1] // hop_length
|
270 |
ref_text_len = len(ref_text.encode("utf-8"))
|
271 |
gen_text_len = len(gen_text.encode("utf-8"))
|
272 |
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
|
|
31 |
n_mel_channels = 100
|
32 |
hop_length = 256
|
33 |
target_rms = 0.1
|
34 |
+
cross_fade_duration = 0.15
|
35 |
+
ode_method = "euler"
|
36 |
+
nfe_step = 32 # 16, 32
|
37 |
+
cfg_strength = 2.0
|
38 |
+
sway_sampling_coef = -1.0
|
39 |
+
speed = 1.0
|
40 |
+
fix_duration = None
|
41 |
|
42 |
# -----------------------------------------
|
43 |
|
|
|
108 |
# load model for inference
|
109 |
|
110 |
|
111 |
+
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
|
112 |
if vocab_file == "":
|
113 |
vocab_file = "Emilia_ZH_EN"
|
114 |
tokenizer = "pinyin"
|
|
|
193 |
ref_text,
|
194 |
gen_text,
|
195 |
model_obj,
|
|
|
|
|
196 |
show_info=print,
|
197 |
progress=tqdm,
|
198 |
+
target_rms=target_rms,
|
199 |
+
cross_fade_duration=cross_fade_duration,
|
200 |
+
nfe_step=nfe_step,
|
201 |
+
cfg_strength=cfg_strength,
|
202 |
+
sway_sampling_coef=sway_sampling_coef,
|
203 |
+
speed=speed,
|
204 |
+
fix_duration=fix_duration,
|
205 |
):
|
206 |
# Split the input text into batches
|
207 |
audio, sr = torchaudio.load(ref_audio)
|
|
|
216 |
ref_text,
|
217 |
gen_text_batches,
|
218 |
model_obj,
|
219 |
+
progress=progress,
|
220 |
+
target_rms=target_rms,
|
221 |
+
cross_fade_duration=cross_fade_duration,
|
222 |
+
nfe_step=nfe_step,
|
223 |
+
cfg_strength=cfg_strength,
|
224 |
+
sway_sampling_coef=sway_sampling_coef,
|
225 |
+
speed=speed,
|
226 |
+
fix_duration=fix_duration,
|
227 |
)
|
228 |
|
229 |
|
|
|
235 |
ref_text,
|
236 |
gen_text_batches,
|
237 |
model_obj,
|
|
|
|
|
238 |
progress=tqdm,
|
239 |
+
target_rms=0.1,
|
240 |
+
cross_fade_duration=0.15,
|
241 |
nfe_step=32,
|
242 |
cfg_strength=2.0,
|
243 |
sway_sampling_coef=-1,
|
244 |
+
speed=1,
|
245 |
fix_duration=None,
|
246 |
):
|
247 |
audio, sr = ref_audio
|
|
|
266 |
text_list = [ref_text + gen_text]
|
267 |
final_text_list = convert_char_to_pinyin(text_list)
|
268 |
|
269 |
+
ref_audio_len = audio.shape[-1] // hop_length
|
270 |
if fix_duration is not None:
|
271 |
duration = int(fix_duration * target_sample_rate / hop_length)
|
272 |
else:
|
273 |
# Calculate duration
|
|
|
274 |
ref_text_len = len(ref_text.encode("utf-8"))
|
275 |
gen_text_len = len(gen_text.encode("utf-8"))
|
276 |
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|