mrfakename commited on
Commit
cc950e4
1 Parent(s): 1dda0cc

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
  name = "f5-tts"
7
- version = "0.1.1"
8
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
  readme = "README.md"
10
  license = {text = "MIT License"}
 
4
 
5
  [project]
6
  name = "f5-tts"
7
+ version = "0.1.2"
8
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
  readme = "README.md"
10
  license = {text = "MIT License"}
src/f5_tts/model/dataset.py CHANGED
@@ -127,38 +127,43 @@ class CustomDataset(Dataset):
127
  return len(self.data)
128
 
129
  def __getitem__(self, index):
130
- row = self.data[index]
131
- audio_path = row["audio_path"]
132
- text = row["text"]
133
- duration = row["duration"]
 
 
 
 
 
 
 
134
 
135
  if self.preprocessed_mel:
136
  mel_spec = torch.tensor(row["mel_spec"])
137
-
138
  else:
139
  audio, source_sample_rate = torchaudio.load(audio_path)
 
 
140
  if audio.shape[0] > 1:
141
  audio = torch.mean(audio, dim=0, keepdim=True)
142
 
143
- if duration > 30 or duration < 0.3:
144
- return self.__getitem__((index + 1) % len(self.data))
145
-
146
  if source_sample_rate != self.target_sample_rate:
147
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
148
  audio = resampler(audio)
149
 
 
150
  mel_spec = self.mel_spectrogram(audio)
151
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
152
 
153
- return dict(
154
- mel_spec=mel_spec,
155
- text=text,
156
- )
157
 
158
 
159
  # Dynamic Batch Sampler
160
-
161
-
162
  class DynamicBatchSampler(Sampler[list[int]]):
163
  """Extension of Sampler that will do the following:
164
  1. Change the batch size (essentially number of sequences)
 
127
  return len(self.data)
128
 
129
  def __getitem__(self, index):
130
+ while True:
131
+ row = self.data[index]
132
+ audio_path = row["audio_path"]
133
+ text = row["text"]
134
+ duration = row["duration"]
135
+
136
+ # filter by given length
137
+ if 0.3 <= duration <= 30:
138
+ break # valid
139
+
140
+ index = (index + 1) % len(self.data)
141
 
142
  if self.preprocessed_mel:
143
  mel_spec = torch.tensor(row["mel_spec"])
 
144
  else:
145
  audio, source_sample_rate = torchaudio.load(audio_path)
146
+
147
+ # make sure mono input
148
  if audio.shape[0] > 1:
149
  audio = torch.mean(audio, dim=0, keepdim=True)
150
 
151
+ # resample if necessary
 
 
152
  if source_sample_rate != self.target_sample_rate:
153
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
154
  audio = resampler(audio)
155
 
156
+ # to mel spectrogram
157
  mel_spec = self.mel_spectrogram(audio)
158
+ mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
159
 
160
+ return {
161
+ "mel_spec": mel_spec,
162
+ "text": text,
163
+ }
164
 
165
 
166
  # Dynamic Batch Sampler
 
 
167
  class DynamicBatchSampler(Sampler[list[int]]):
168
  """Extension of Sampler that will do the following:
169
  1. Change the batch size (essentially number of sequences)
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1177,7 +1177,10 @@ def get_random_sample_transcribe(project_name):
1177
  sp = item.split("|")
1178
  if len(sp) != 2:
1179
  continue
1180
- list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
 
 
 
1181
 
1182
  if list_data == []:
1183
  return "", None
 
1177
  sp = item.split("|")
1178
  if len(sp) != 2:
1179
  continue
1180
+
1181
+ # fixed audio when it is absolute
1182
+ file_audio = get_correct_audio_path(sp[0], path_project)
1183
+ list_data.append([file_audio, sp[1]])
1184
 
1185
  if list_data == []:
1186
  return "", None