Gregniuki committed on
Commit 218ae0c
1 Parent(s): eda0fb4

Upload 6 files

Files changed (5)
  1. model/__init__.py +8 -5
  2. model/cfm.py +3 -3
  3. model/dataset.py +22 -17
  4. model/trainer.py +30 -17
  5. model/utils.py +54 -467
model/__init__.py CHANGED
@@ -1,7 +1,10 @@
- from model.cfm import CFM
+ from f5_tts.model.cfm import CFM
 
- from model.backbones.unett import UNetT
- from model.backbones.dit import DiT
- from model.backbones.mmdit import MMDiT
+ from f5_tts.model.backbones.unett import UNetT
+ from f5_tts.model.backbones.dit import DiT
+ from f5_tts.model.backbones.mmdit import MMDiT
 
- from model.trainer import Trainer
+ from f5_tts.model.trainer import Trainer
+
+
+ __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
model/cfm.py CHANGED
@@ -18,8 +18,8 @@ from torch import nn
  from torch.nn.utils.rnn import pad_sequence
  from torchdiffeq import odeint
 
- from model.modules import MelSpec
- from model.utils import (
+ from f5_tts.model.modules import MelSpec
+ from f5_tts.model.utils import (
      default,
      exists,
      lens_to_mask,
@@ -193,7 +193,7 @@ class CFM(nn.Module):
      y0 = (1 - t_start) * y0 + t_start * test_cond
      steps = int(steps * (1 - t_start))
 
-     t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
+     t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
      if sway_sampling_coef is not None:
          t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
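The linspace change matters because torchdiffeq's odeint evaluates the trajectory at every node of the time grid it is given, so a grid of steps + 1 nodes yields exactly steps solver intervals, while the previous steps-node grid ran one interval short. A minimal sketch of the grid construction under those conventions; steps and sway_sampling_coef below are arbitrary example values:

import torch

t_start, steps, sway_sampling_coef = 0.0, 32, -1.0

t = torch.linspace(t_start, 1, steps + 1)  # steps + 1 nodes -> steps integration intervals
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)  # sway sampling warp from the diff

assert t.numel() == steps + 1
assert torch.allclose(t[[0, -1]], torch.tensor([t_start, 1.0]))  # endpoints survive the warp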
 
model/dataset.py CHANGED
@@ -11,8 +11,8 @@ from torch import nn
  from torch.utils.data import Dataset, Sampler
  from tqdm import tqdm
 
- from model.modules import MelSpec
- from model.utils import default
+ from f5_tts.model.modules import MelSpec
+ from f5_tts.model.utils import default
 
 
  class HFDataset(Dataset):
@@ -127,38 +127,43 @@ class CustomDataset(Dataset):
          return len(self.data)
 
      def __getitem__(self, index):
-         row = self.data[index]
-         audio_path = row["audio_path"]
-         text = row["text"]
-         duration = row["duration"]
+         while True:
+             row = self.data[index]
+             audio_path = row["audio_path"]
+             text = row["text"]
+             duration = row["duration"]
+
+             # filter by given length
+             if 0.3 <= duration <= 30:
+                 break  # valid
+
+             index = (index + 1) % len(self.data)
 
          if self.preprocessed_mel:
              mel_spec = torch.tensor(row["mel_spec"])
-
          else:
              audio, source_sample_rate = torchaudio.load(audio_path)
+
+             # make sure mono input
              if audio.shape[0] > 1:
                  audio = torch.mean(audio, dim=0, keepdim=True)
 
-             if duration > 30 or duration < 0.3:
-                 return self.__getitem__((index + 1) % len(self.data))
-
+             # resample if necessary
              if source_sample_rate != self.target_sample_rate:
                  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
                  audio = resampler(audio)
 
+             # to mel spectrogram
              mel_spec = self.mel_spectrogram(audio)
-             mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t')
+             mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'
 
-         return dict(
-             mel_spec=mel_spec,
-             text=text,
-         )
+         return {
+             "mel_spec": mel_spec,
+             "text": text,
+         }
 
 
  # Dynamic Batch Sampler
-
-
  class DynamicBatchSampler(Sampler[list[int]]):
      """Extension of Sampler that will do the following:
      1. Change the batch size (essentially number of sequences)
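The length filter now runs before any audio is touched and advances iteratively instead of recursing into __getitem__, so a long run of out-of-range rows can no longer overflow Python's recursion limit. A standalone sketch of the same wrap-around scan; the toy data list below is hypothetical:

data = [{"duration": 0.1}, {"duration": 45.0}, {"duration": 5.2}]  # durations in seconds

index = 0
while True:
    row = data[index]
    if 0.3 <= row["duration"] <= 30:   # same bounds as the dataset filter
        break                          # found a valid sample
    # wrap around instead of recursing; the real dataset assumes at least one valid row
    index = (index + 1) % len(data)

print(index)  # -> 2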
model/trainer.py CHANGED
@@ -14,9 +14,9 @@ from torch.optim.lr_scheduler import LinearLR, SequentialLR
  from torch.utils.data import DataLoader, Dataset, SequentialSampler
  from tqdm import tqdm
 
- from model import CFM
- from model.dataset import DynamicBatchSampler, collate_fn
- from model.utils import default, exists
+ from f5_tts.model import CFM
+ from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
+ from f5_tts.model.utils import default, exists
 
  # trainer
 
@@ -47,6 +47,8 @@ class Trainer:
          ema_kwargs: dict = dict(),
          bnb_optimizer: bool = False,
          mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
+         is_local_vocoder: bool = False,  # use local path vocoder
+         local_vocoder_path: str = "",  # local vocoder path
      ):
          ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
@@ -108,7 +110,11 @@ class Trainer:
          self.max_samples = max_samples
          self.grad_accumulation_steps = grad_accumulation_steps
          self.max_grad_norm = max_grad_norm
+
+         # mel vocoder config
          self.vocoder_name = mel_spec_type
+         self.is_local_vocoder = is_local_vocoder
+         self.local_vocoder_path = local_vocoder_path
 
          self.noise_scheduler = noise_scheduler
 
@@ -148,7 +154,7 @@ class Trainer:
          if (
              not exists(self.checkpoint_path)
              or not os.path.exists(self.checkpoint_path)
-             or not os.listdir(self.checkpoint_path)
+             or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
          ):
              return 0
 
@@ -199,7 +205,9 @@ class Trainer:
          if self.log_samples:
              from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
 
-             vocoder = load_vocoder(vocoder_name=self.vocoder_name)
+             vocoder = load_vocoder(
+                 vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
+             )
              target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
              log_samples_path = f"{self.checkpoint_path}/samples"
              os.makedirs(log_samples_path, exist_ok=True)
@@ -324,26 +332,31 @@ class Trainer:
              self.save_checkpoint(global_step)
 
              if self.log_samples and self.accelerator.is_local_main_process:
-                 ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
-                 torchaudio.save(
-                     f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
-                 )
+                 ref_audio_len = mel_lengths[0]
+                 infer_text = [
+                     text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
+                 ]
                  with torch.inference_mode():
                      generated, _ = self.accelerator.unwrap_model(self.model).sample(
                          cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
-                         text=[text_inputs[0] + [" "] + text_inputs[0]],
+                         text=infer_text,
                          duration=ref_audio_len * 2,
                          steps=nfe_step,
                          cfg_strength=cfg_strength,
                          sway_sampling_coef=sway_sampling_coef,
                      )
-                 generated = generated.to(torch.float32)
-                 gen_audio = vocoder.decode(
-                     generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
-                 )
-                 torchaudio.save(
-                     f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
-                 )
+                 generated = generated.to(torch.float32)
+                 gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                 ref_mel_spec = batch["mel"][0].unsqueeze(0)
+                 if self.vocoder_name == "vocos":
+                     gen_audio = vocoder.decode(gen_mel_spec).cpu()
+                     ref_audio = vocoder.decode(ref_mel_spec).cpu()
+                 elif self.vocoder_name == "bigvgan":
+                     gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
+                     ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
+
+                 torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                 torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
 
              if global_step % self.last_per_steps == 0:
                  self.save_checkpoint(global_step, last=True)
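The sample-logging rework exists because the two supported vocoders expose different call conventions: Vocos provides a decode(mel) method, while BigVGAN is invoked as a plain module and returns an extra leading dimension that has to be squeezed before torchaudio.save. A minimal sketch of the same branching, assuming vocoder is whatever load_vocoder returned and mel is a (1, n_mels, frames) tensor; mel_to_audio is a hypothetical helper name:

import torch

def mel_to_audio(vocoder, mel: torch.Tensor, vocoder_name: str) -> torch.Tensor:
    with torch.inference_mode():
        if vocoder_name == "vocos":
            return vocoder.decode(mel).cpu()        # Vocos exposes decode()
        if vocoder_name == "bigvgan":
            return vocoder(mel).squeeze(0).cpu()    # BigVGAN is called as a module; drop the extra dim
    raise ValueError(f"unsupported vocoder: {vocoder_name}")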
model/utils.py CHANGED
@@ -1,135 +1,116 @@
  from __future__ import annotations
 
  import os
- import re
- import math
  import random
- import string
- from tqdm import tqdm
  from collections import defaultdict
-
- import matplotlib
- matplotlib.use("Agg")
- import matplotlib.pylab as plt
 
  import torch
- import torch.nn.functional as F
  from torch.nn.utils.rnn import pad_sequence
- import torchaudio
-
- import einx
- from einops import rearrange, reduce
 
  import jieba
  from pypinyin import lazy_pinyin, Style
 
- from model.ecapa_tdnn import ECAPA_TDNN_SMALL
- from model.modules import MelSpec
-
 
  # seed everything
 
- def seed_everything(seed = 0):
      random.seed(seed)
-     os.environ['PYTHONHASHSEED'] = str(seed)
      torch.manual_seed(seed)
      torch.cuda.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)
      torch.backends.cudnn.deterministic = True
      torch.backends.cudnn.benchmark = False
 
  # helpers
 
  def exists(v):
      return v is not None
 
  def default(v, d):
      return v if exists(v) else d
 
  # tensor helpers
 
- def lens_to_mask(
-     t: int['b'],
-     length: int | None = None
- ) -> bool['b n']:
 
      if not exists(length):
          length = t.amax()
 
-     seq = torch.arange(length, device = t.device)
-     return einx.less('n, b -> b n', seq, t)
-
- def mask_from_start_end_indices(
-     seq_len: int['b'],
-     start: int['b'],
-     end: int['b']
- ):
-     max_seq_len = seq_len.max().item()
-     seq = torch.arange(max_seq_len, device = start.device).long()
-     return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
-
- def mask_from_frac_lengths(
-     seq_len: int['b'],
-     frac_lengths: float['b']
- ):
      lengths = (frac_lengths * seq_len).long()
      max_start = seq_len - lengths
 
      rand = torch.rand_like(frac_lengths)
-     start = (max_start * rand).long().clamp(min = 0)
      end = start + lengths
 
      return mask_from_start_end_indices(seq_len, start, end)
 
- def maybe_masked_mean(
-     t: float['b n d'],
-     mask: bool['b n'] = None
- ) -> float['b d']:
 
      if not exists(mask):
-         return t.mean(dim = 1)
 
-     t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
-     num = reduce(t, 'b n d -> b d', 'sum')
-     den = reduce(mask.float(), 'b n -> b', 'sum')
 
-     return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
 
 
  # simple utf-8 tokenizer, since paper went character based
- def list_str_to_tensor(
-     text: list[str],
-     padding_value = -1
- ) -> int['b nt']:
-     list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
-     text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
      return text
 
  # char tokenizer, based on custom dataset's extracted .txt file
  def list_str_to_idx(
      text: list[str] | list[list[str]],
      vocab_char_map: dict[str, int], # {char: idx}
-     padding_value = -1
- ) -> int['b nt']:
      list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
-     text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
      return text
 
 
  # Get tokenizer
 
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
-     '''
      tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                  - "char" for char-wise tokenizer, need .txt vocab_file
                  - "byte" for utf-8 tokenizer
                  - "custom" if you're directly passing in a path to the vocab.txt you want to use
      vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                  - if use "char", derived from unfiltered character & symbol counts of custom dataset
-                 - if use "byte", set to 256 (unicode byte range)
-     '''
      if tokenizer in ["pinyin", "char"]:
-         with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
              vocab_char_map = {}
              for i, char in enumerate(f):
                  vocab_char_map[char[:-1]] = i
@@ -139,8 +120,9 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
      elif tokenizer == "byte":
          vocab_char_map = None
          vocab_size = 256
      elif tokenizer == "custom":
-         with open (dataset_name, "r", encoding="utf-8") as f:
              vocab_char_map = {}
              for i, char in enumerate(f):
                  vocab_char_map[char[:-1]] = i
@@ -151,16 +133,19 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
 
  # convert char to pinyin
 
- def convert_char_to_pinyin(text_list, polyphone = True):
      final_text_list = []
-     god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
-     custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
      for text in text_list:
          char_list = []
          text = text.translate(god_knows_why_en_testset_contains_zh_quote)
          text = text.translate(custom_trans)
          for seg in jieba.cut(text):
-             seg_byte_len = len(bytes(seg, 'UTF-8'))
              if seg_byte_len == len(seg): # if pure alphabets and symbols
                  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                      char_list.append(" ")
@@ -186,413 +171,15 @@ def convert_char_to_pinyin(text_list, polyphone = True):
186
  return final_text_list
187
 
188
 
189
- # save spectrogram
190
- def save_spectrogram(spectrogram, path):
191
- plt.figure(figsize=(12, 4))
192
- plt.imshow(spectrogram, origin='lower', aspect='auto')
193
- plt.colorbar()
194
- plt.savefig(path)
195
- plt.close()
196
-
197
-
198
- # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
199
- def get_seedtts_testset_metainfo(metalst):
200
- f = open(metalst); lines = f.readlines(); f.close()
201
- metainfo = []
202
- for line in lines:
203
- if len(line.strip().split('|')) == 5:
204
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
205
- elif len(line.strip().split('|')) == 4:
206
- utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
207
- gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
208
- if not os.path.isabs(prompt_wav):
209
- prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
210
- metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
211
- return metainfo
212
-
213
-
214
- # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
215
- def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
216
- f = open(metalst); lines = f.readlines(); f.close()
217
- metainfo = []
218
- for line in lines:
219
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
220
-
221
- # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
222
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
223
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
224
-
225
- # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
226
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
227
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
228
-
229
- metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
230
-
231
- return metainfo
232
-
233
-
234
- # padded to max length mel batch
235
- def padded_mel_batch(ref_mels):
236
- max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
237
- padded_ref_mels = []
238
- for mel in ref_mels:
239
- padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
240
- padded_ref_mels.append(padded_ref_mel)
241
- padded_ref_mels = torch.stack(padded_ref_mels)
242
- padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
243
- return padded_ref_mels
244
-
245
-
246
- # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
247
-
248
- def get_inference_prompt(
249
- metainfo,
250
- speed = 1., tokenizer = "pinyin", polyphone = True,
251
- target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
252
- use_truth_duration = False,
253
- infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
254
- ):
255
- prompts_all = []
256
-
257
- min_tokens = min_secs * target_sample_rate // hop_length
258
- max_tokens = max_secs * target_sample_rate // hop_length
259
-
260
- batch_accum = [0] * num_buckets
261
- utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
262
- ([[] for _ in range(num_buckets)] for _ in range(6))
263
-
264
- mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
265
-
266
- for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
267
-
268
- # Audio
269
- ref_audio, ref_sr = torchaudio.load(prompt_wav)
270
- ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
271
- if ref_rms < target_rms:
272
- ref_audio = ref_audio * target_rms / ref_rms
273
- assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
274
- if ref_sr != target_sample_rate:
275
- resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
276
- ref_audio = resampler(ref_audio)
277
-
278
- # Text
279
- if len(prompt_text[-1].encode('utf-8')) == 1:
280
- prompt_text = prompt_text + " "
281
- text = [prompt_text + gt_text]
282
- if tokenizer == "pinyin":
283
- text_list = convert_char_to_pinyin(text, polyphone = polyphone)
284
- else:
285
- text_list = text
286
-
287
- # Duration, mel frame length
288
- ref_mel_len = ref_audio.shape[-1] // hop_length
289
- if use_truth_duration:
290
- gt_audio, gt_sr = torchaudio.load(gt_wav)
291
- if gt_sr != target_sample_rate:
292
- resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
293
- gt_audio = resampler(gt_audio)
294
- total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
295
-
296
- # # test vocoder resynthesis
297
- # ref_audio = gt_audio
298
- else:
299
- zh_pause_punc = r"。,、;:?!"
300
- ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
301
- gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
302
- total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
303
-
304
- # to mel spectrogram
305
- ref_mel = mel_spectrogram(ref_audio)
306
- ref_mel = rearrange(ref_mel, '1 d n -> d n')
307
-
308
- # deal with batch
309
- assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
310
- assert min_tokens <= total_mel_len <= max_tokens, \
311
- f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
312
- bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
313
-
314
- utts[bucket_i].append(utt)
315
- ref_rms_list[bucket_i].append(ref_rms)
316
- ref_mels[bucket_i].append(ref_mel)
317
- ref_mel_lens[bucket_i].append(ref_mel_len)
318
- total_mel_lens[bucket_i].append(total_mel_len)
319
- final_text_list[bucket_i].extend(text_list)
320
-
321
- batch_accum[bucket_i] += total_mel_len
322
-
323
- if batch_accum[bucket_i] >= infer_batch_size:
324
- # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
325
- prompts_all.append((
326
- utts[bucket_i],
327
- ref_rms_list[bucket_i],
328
- padded_mel_batch(ref_mels[bucket_i]),
329
- ref_mel_lens[bucket_i],
330
- total_mel_lens[bucket_i],
331
- final_text_list[bucket_i]
332
- ))
333
- batch_accum[bucket_i] = 0
334
- utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
335
-
336
- # add residual
337
- for bucket_i, bucket_frames in enumerate(batch_accum):
338
- if bucket_frames > 0:
339
- prompts_all.append((
340
- utts[bucket_i],
341
- ref_rms_list[bucket_i],
342
- padded_mel_batch(ref_mels[bucket_i]),
343
- ref_mel_lens[bucket_i],
344
- total_mel_lens[bucket_i],
345
- final_text_list[bucket_i]
346
- ))
347
- # not only leave easy work for last workers
348
- random.seed(666)
349
- random.shuffle(prompts_all)
350
-
351
- return prompts_all
352
-
353
-
354
- # get wav_res_ref_text of seed-tts test metalst
355
- # https://github.com/BytedanceSpeech/seed-tts-eval
356
-
357
- def get_seed_tts_test(metalst, gen_wav_dir, gpus):
358
- f = open(metalst)
359
- lines = f.readlines()
360
- f.close()
361
-
362
- test_set_ = []
363
- for line in tqdm(lines):
364
- if len(line.strip().split('|')) == 5:
365
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
366
- elif len(line.strip().split('|')) == 4:
367
- utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
368
-
369
- if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
370
- continue
371
- gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
372
- if not os.path.isabs(prompt_wav):
373
- prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
374
-
375
- test_set_.append((gen_wav, prompt_wav, gt_text))
376
-
377
- num_jobs = len(gpus)
378
- if num_jobs == 1:
379
- return [(gpus[0], test_set_)]
380
-
381
- wav_per_job = len(test_set_) // num_jobs + 1
382
- test_set = []
383
- for i in range(num_jobs):
384
- test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
385
-
386
- return test_set
387
-
388
-
389
- # get librispeech test-clean cross sentence test
390
-
391
- def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
392
- f = open(metalst)
393
- lines = f.readlines()
394
- f.close()
395
-
396
- test_set_ = []
397
- for line in tqdm(lines):
398
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
399
-
400
- if eval_ground_truth:
401
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
402
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
403
- else:
404
- if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
405
- raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
406
- gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
407
-
408
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
409
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
410
-
411
- test_set_.append((gen_wav, ref_wav, gen_txt))
412
-
413
- num_jobs = len(gpus)
414
- if num_jobs == 1:
415
- return [(gpus[0], test_set_)]
416
-
417
- wav_per_job = len(test_set_) // num_jobs + 1
418
- test_set = []
419
- for i in range(num_jobs):
420
- test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
421
-
422
- return test_set
423
-
424
-
425
- # load asr model
426
-
427
- def load_asr_model(lang, ckpt_dir = ""):
428
- if lang == "zh":
429
- from funasr import AutoModel
430
- model = AutoModel(
431
- model = os.path.join(ckpt_dir, "paraformer-zh"),
432
- # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
433
- # punc_model = os.path.join(ckpt_dir, "ct-punc"),
434
- # spk_model = os.path.join(ckpt_dir, "cam++"),
435
- disable_update=True,
436
- ) # following seed-tts setting
437
- elif lang == "en":
438
- from faster_whisper import WhisperModel
439
- model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
440
- model = WhisperModel(model_size, device="cuda", compute_type="float16")
441
- return model
442
-
443
-
444
- # WER Evaluation, the way Seed-TTS does
445
-
446
- def run_asr_wer(args):
447
- rank, lang, test_set, ckpt_dir = args
448
-
449
- if lang == "zh":
450
- import zhconv
451
- torch.cuda.set_device(rank)
452
- elif lang == "en":
453
- os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
454
- else:
455
- raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
456
-
457
- asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
458
-
459
- from zhon.hanzi import punctuation
460
- punctuation_all = punctuation + string.punctuation
461
- wers = []
462
-
463
- from jiwer import compute_measures
464
- for gen_wav, prompt_wav, truth in tqdm(test_set):
465
- if lang == "zh":
466
- res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
467
- hypo = res[0]["text"]
468
- hypo = zhconv.convert(hypo, 'zh-cn')
469
- elif lang == "en":
470
- segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
471
- hypo = ''
472
- for segment in segments:
473
- hypo = hypo + ' ' + segment.text
474
-
475
- # raw_truth = truth
476
- # raw_hypo = hypo
477
-
478
- for x in punctuation_all:
479
- truth = truth.replace(x, '')
480
- hypo = hypo.replace(x, '')
481
-
482
- truth = truth.replace(' ', ' ')
483
- hypo = hypo.replace(' ', ' ')
484
-
485
- if lang == "zh":
486
- truth = " ".join([x for x in truth])
487
- hypo = " ".join([x for x in hypo])
488
- elif lang == "en":
489
- truth = truth.lower()
490
- hypo = hypo.lower()
491
-
492
- measures = compute_measures(truth, hypo)
493
- wer = measures["wer"]
494
-
495
- # ref_list = truth.split(" ")
496
- # subs = measures["substitutions"] / len(ref_list)
497
- # dele = measures["deletions"] / len(ref_list)
498
- # inse = measures["insertions"] / len(ref_list)
499
-
500
- wers.append(wer)
501
-
502
- return wers
503
-
504
-
505
- # SIM Evaluation
506
-
507
- def run_sim(args):
508
- rank, test_set, ckpt_dir = args
509
- device = f"cuda:{rank}"
510
-
511
- model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
512
- state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
513
- model.load_state_dict(state_dict['model'], strict=False)
514
-
515
- use_gpu=True if torch.cuda.is_available() else False
516
- if use_gpu:
517
- model = model.cuda(device)
518
- model.eval()
519
-
520
- sim_list = []
521
- for wav1, wav2, truth in tqdm(test_set):
522
-
523
- wav1, sr1 = torchaudio.load(wav1)
524
- wav2, sr2 = torchaudio.load(wav2)
525
-
526
- resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
527
- resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
528
- wav1 = resample1(wav1)
529
- wav2 = resample2(wav2)
530
-
531
- if use_gpu:
532
- wav1 = wav1.cuda(device)
533
- wav2 = wav2.cuda(device)
534
- with torch.no_grad():
535
- emb1 = model(wav1)
536
- emb2 = model(wav2)
537
-
538
- sim = F.cosine_similarity(emb1, emb2)[0].item()
539
- # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
540
- sim_list.append(sim)
541
-
542
- return sim_list
543
-
544
-
545
  # filter func for dirty data with many repetitions
546
 
547
- def repetition_found(text, length = 2, tolerance = 10):
 
548
  pattern_count = defaultdict(int)
549
  for i in range(len(text) - length + 1):
550
- pattern = text[i:i + length]
551
  pattern_count[pattern] += 1
552
  for pattern, count in pattern_count.items():
553
  if count > tolerance:
554
  return True
555
  return False
556
-
557
-
558
- # load model checkpoint for inference
559
-
560
- def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
561
- if dtype is None:
562
- dtype = (
563
- torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
564
- )
565
- model = model.to(dtype)
566
-
567
- ckpt_type = ckpt_path.split(".")[-1]
568
- if ckpt_type == "safetensors":
569
- from safetensors.torch import load_file
570
-
571
- checkpoint = load_file(ckpt_path, device=device)
572
- else:
573
- checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
574
-
575
- if use_ema:
576
- if ckpt_type == "safetensors":
577
- checkpoint = {"ema_model_state_dict": checkpoint}
578
- checkpoint["model_state_dict"] = {
579
- k.replace("ema_model.", ""): v
580
- for k, v in checkpoint["ema_model_state_dict"].items()
581
- if k not in ["initted", "step"]
582
- }
583
-
584
- # patch for backward compatibility, 305e3ea
585
- for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
586
- if key in checkpoint["model_state_dict"]:
587
- del checkpoint["model_state_dict"][key]
588
-
589
- model.load_state_dict(checkpoint["model_state_dict"])
590
- else:
591
- if ckpt_type == "safetensors":
592
- checkpoint = {"model_state_dict": checkpoint}
593
- model.load_state_dict(checkpoint["model_state_dict"])
594
-
595
- del checkpoint
596
- torch.cuda.empty_cache()
597
-
598
- return model.to(device)
 
  from __future__ import annotations
 
  import os
  import random
  from collections import defaultdict
+ from importlib.resources import files
 
  import torch
  from torch.nn.utils.rnn import pad_sequence
 
  import jieba
  from pypinyin import lazy_pinyin, Style
 
 
  # seed everything
 
+
+ def seed_everything(seed=0):
      random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
      torch.manual_seed(seed)
      torch.cuda.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)
      torch.backends.cudnn.deterministic = True
      torch.backends.cudnn.benchmark = False
 
+
  # helpers
 
+
  def exists(v):
      return v is not None
 
+
  def default(v, d):
      return v if exists(v) else d
 
+
  # tensor helpers
 
 
+ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
      if not exists(length):
          length = t.amax()
 
+     seq = torch.arange(length, device=t.device)
+     return seq[None, :] < t[:, None]
+
+
+ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
+     max_seq_len = seq_len.max().item()
+     seq = torch.arange(max_seq_len, device=start.device).long()
+     start_mask = seq[None, :] >= start[:, None]
+     end_mask = seq[None, :] < end[:, None]
+     return start_mask & end_mask
+
+
+ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
      lengths = (frac_lengths * seq_len).long()
      max_start = seq_len - lengths
 
      rand = torch.rand_like(frac_lengths)
+     start = (max_start * rand).long().clamp(min=0)
      end = start + lengths
 
      return mask_from_start_end_indices(seq_len, start, end)
 
 
+ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
      if not exists(mask):
+         return t.mean(dim=1)
 
+     t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
+     num = t.sum(dim=1)
+     den = mask.float().sum(dim=1)
 
+     return num / den.clamp(min=1.0)
 
 
  # simple utf-8 tokenizer, since paper went character based
+ def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
+     list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
+     text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
      return text
 
+
  # char tokenizer, based on custom dataset's extracted .txt file
  def list_str_to_idx(
      text: list[str] | list[list[str]],
      vocab_char_map: dict[str, int],  # {char: idx}
+     padding_value=-1,
+ ) -> int["b nt"]:  # noqa: F722
      list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
+     text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
      return text
 
 
  # Get tokenizer
 
+
  def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
+     """
      tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                  - "char" for char-wise tokenizer, need .txt vocab_file
                  - "byte" for utf-8 tokenizer
                  - "custom" if you're directly passing in a path to the vocab.txt you want to use
      vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                  - if use "char", derived from unfiltered character & symbol counts of custom dataset
+                 - if use "byte", set to 256 (unicode byte range)
+     """
      if tokenizer in ["pinyin", "char"]:
+         tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
+         with open(tokenizer_path, "r", encoding="utf-8") as f:
              vocab_char_map = {}
              for i, char in enumerate(f):
                  vocab_char_map[char[:-1]] = i
 
      elif tokenizer == "byte":
          vocab_char_map = None
          vocab_size = 256
+
      elif tokenizer == "custom":
+         with open(dataset_name, "r", encoding="utf-8") as f:
              vocab_char_map = {}
              for i, char in enumerate(f):
                  vocab_char_map[char[:-1]] = i
 
 
  # convert char to pinyin
 
+
+ def convert_char_to_pinyin(text_list, polyphone=True):
      final_text_list = []
+     god_knows_why_en_testset_contains_zh_quote = str.maketrans(
+         {"“": '"', "”": '"', "‘": "'", "’": "'"}
+     )  # in case librispeech (orig no-pc) test-clean
+     custom_trans = str.maketrans({";": ","})  # add custom trans here, to address oov
      for text in text_list:
          char_list = []
          text = text.translate(god_knows_why_en_testset_contains_zh_quote)
          text = text.translate(custom_trans)
          for seg in jieba.cut(text):
+             seg_byte_len = len(bytes(seg, "UTF-8"))
              if seg_byte_len == len(seg):  # if pure alphabets and symbols
                  if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                      char_list.append(" ")
 
      return final_text_list
 
  # filter func for dirty data with many repetitions
 
+
+ def repetition_found(text, length=2, tolerance=10):
      pattern_count = defaultdict(int)
      for i in range(len(text) - length + 1):
+         pattern = text[i : i + length]
          pattern_count[pattern] += 1
      for pattern, count in pattern_count.items():
          if count > tolerance:
              return True
      return False
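The rewritten tensor helpers drop the einx/einops dependencies in favor of plain broadcasting: comparing an (n,) position vector against a (b, 1) length column produces the same (b, n) mask as the removed einx.less call. A small sketch of that equivalence with made-up lengths:

import torch

t = torch.tensor([2, 4, 1])           # per-sample lengths, shape (b,)
seq = torch.arange(int(t.amax()))     # positions 0..n-1, shape (n,)
mask = seq[None, :] < t[:, None]      # broadcast to (b, n); True inside each length

print(mask)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True, False, False, False]])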