mrfakename
committed on
Commit
•
35005eb
1
Parent(s):
2669b3f
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- model/utils_infer.py +25 -7
model/utils_infer.py
CHANGED
@@ -19,8 +19,14 @@ from model.utils import (
|
|
19 |
convert_char_to_pinyin,
|
20 |
)
|
21 |
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
26 |
|
@@ -76,7 +82,9 @@ def chunk_text(text, max_chars=135):
|
|
76 |
|
77 |
|
78 |
# load vocoder
|
79 |
-
def load_vocoder(is_local=False, local_path="", device=
|
|
|
|
|
80 |
if is_local:
|
81 |
print(f"Load vocos from local path {local_path}")
|
82 |
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
|
@@ -94,8 +102,10 @@ def load_vocoder(is_local=False, local_path="", device=device):
|
|
94 |
asr_pipe = None
|
95 |
|
96 |
|
97 |
-
def initialize_asr_pipeline(device=
|
98 |
global asr_pipe
|
|
|
|
|
99 |
|
100 |
asr_pipe = pipeline(
|
101 |
"automatic-speech-recognition",
|
@@ -108,7 +118,9 @@ def initialize_asr_pipeline(device=device):
|
|
108 |
# load model for inference
|
109 |
|
110 |
|
111 |
-
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=
|
|
|
|
|
112 |
if vocab_file == "":
|
113 |
vocab_file = "Emilia_ZH_EN"
|
114 |
tokenizer = "pinyin"
|
@@ -141,7 +153,9 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_me
|
|
141 |
# preprocess reference audio and text
|
142 |
|
143 |
|
144 |
-
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=
|
|
|
|
|
145 |
show_info("Converting audio...")
|
146 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
147 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
@@ -243,7 +257,11 @@ def infer_batch_process(
|
|
243 |
sway_sampling_coef=-1,
|
244 |
speed=1,
|
245 |
fix_duration=None,
|
|
|
246 |
):
|
|
|
|
|
|
|
247 |
audio, sr = ref_audio
|
248 |
if audio.shape[0] > 1:
|
249 |
audio = torch.mean(audio, dim=0, keepdim=True)
|
@@ -254,7 +272,7 @@ def infer_batch_process(
|
|
254 |
if sr != target_sample_rate:
|
255 |
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
256 |
audio = resampler(audio)
|
257 |
-
audio = audio.to(
|
258 |
|
259 |
generated_waves = []
|
260 |
spectrograms = []
|
|
|
19 |
convert_char_to_pinyin,
|
20 |
)
|
21 |
|
22 |
+
# get device
|
23 |
+
|
24 |
+
|
25 |
+
def get_device(device=None):
    """Return the device to run on.

    If *device* is given (anything other than ``None``), it is returned
    unchanged, so callers can write ``device = get_device(device)`` to resolve
    an optional argument. Otherwise the best available backend is picked in
    the order ``"cuda"``, ``"mps"``, ``"cpu"``.

    Args:
        device: Optional explicit device chosen by the caller.

    Returns:
        The caller-supplied *device*, or one of the strings ``"cuda"``,
        ``"mps"`` or ``"cpu"``.
    """
    if device is not None:
        return device

    if torch.cuda.is_available():
        return "cuda"

    # Older torch builds do not ship the MPS backend module at all, so look
    # it up defensively instead of assuming the attribute exists.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"
|
29 |
+
|
30 |
|
31 |
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
32 |
|
|
|
82 |
|
83 |
|
84 |
# load vocoder
|
85 |
+
def load_vocoder(is_local=False, local_path="", device=None):
|
86 |
+
if device is None:
|
87 |
+
device = get_device()
|
88 |
if is_local:
|
89 |
print(f"Load vocos from local path {local_path}")
|
90 |
vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
|
|
|
102 |
asr_pipe = None
|
103 |
|
104 |
|
105 |
+
def initialize_asr_pipeline(device=None):
|
106 |
global asr_pipe
|
107 |
+
if device is None:
|
108 |
+
device = get_device()
|
109 |
|
110 |
asr_pipe = pipeline(
|
111 |
"automatic-speech-recognition",
|
|
|
118 |
# load model for inference
|
119 |
|
120 |
|
121 |
+
def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=None):
|
122 |
+
if device is None:
|
123 |
+
device = get_device()
|
124 |
if vocab_file == "":
|
125 |
vocab_file = "Emilia_ZH_EN"
|
126 |
tokenizer = "pinyin"
|
|
|
153 |
# preprocess reference audio and text
|
154 |
|
155 |
|
156 |
+
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=None):
|
157 |
+
device = get_device(device)
|
158 |
+
|
159 |
show_info("Converting audio...")
|
160 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
161 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
|
|
257 |
sway_sampling_coef=-1,
|
258 |
speed=1,
|
259 |
fix_duration=None,
|
260 |
+
device=None,
|
261 |
):
|
262 |
+
if device is None:
|
263 |
+
device = get_device()
|
264 |
+
|
265 |
audio, sr = ref_audio
|
266 |
if audio.shape[0] > 1:
|
267 |
audio = torch.mean(audio, dim=0, keepdim=True)
|
|
|
272 |
if sr != target_sample_rate:
|
273 |
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
274 |
audio = resampler(audio)
|
275 |
+
audio = audio.to()
|
276 |
|
277 |
generated_waves = []
|
278 |
spectrograms = []
|