trying to add sentence chunking

#1
by drewThomasson - opened
Files changed (1)
  1. app.py +146 -27
app.py CHANGED
@@ -3,9 +3,10 @@ import base64
 import time
 import uuid
 import shutil
+import hashlib
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Tuple
 import subprocess

 import ebooklib
@@ -74,14 +75,99 @@ def clone_voice(audio_path: str):
         audio_data = base64.b64encode(f.read()).decode('utf-8')
     return audio_data

-def process_text_and_generate(input_text, ref_audio_files, speed, enhance_speech, temperature, top_p, top_k, repetition_penalty, language, *args):
+def chunk_text(text: str, max_words: int = 300) -> List[str]:
+    """
+    Splits the input text into chunks with a maximum of `max_words` per chunk.
+    """
+    words = text.split()
+    chunks = []
+    for i in range(0, len(words), max_words):
+        chunk = ' '.join(words[i:i + max_words])
+        chunks.append(chunk)
+    return chunks
+
+def generate_audio_from_chunks(
+    chunks: List[str],
+    ref_audio_files: List[str],
+    speed: float,
+    enhance_speech: bool,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+    repetition_penalty: float,
+    language: str
+) -> Tuple[Optional[str], str]:
+    """
+    Generates audio for each text chunk and combines them into a single audio file.
+    Returns the path to the combined audio file and a log message.
+    """
+    audio_files = []
+    log_messages = ""
+
+    for idx, chunk in enumerate(chunks):
+        result, log = process_text_and_generate(
+            chunk, ref_audio_files, speed, enhance_speech, temperature,
+            top_p, top_k, repetition_penalty, language
+        )
+        if result:
+            sample_rate, audio_array = result
+            # Save audio array to temp file
+            audio_path = temp_dir / f"chunk_{uuid.uuid4().hex[:8]}_{idx}.wav"
+            audio_tensor = torch.from_numpy(audio_array)
+            torchaudio.save(str(audio_path), audio_tensor.unsqueeze(0), sample_rate)
+            audio_files.append(str(audio_path))
+            log_messages += f"✅ Generated audio for chunk {idx + 1}/{len(chunks)}\n"
+        else:
+            logger.error(f"Failed to generate audio for chunk {idx}: {log}")
+            log_messages += f"❌ Failed to generate audio for chunk {idx + 1}: {log}\n"
+            return None, log_messages
+
+    # Create a list file for ffmpeg
+    list_file = temp_dir / f"list_{uuid.uuid4().hex[:8]}.txt"
+    with open(list_file, 'w') as f:
+        for audio_file in audio_files:
+            f.write(f"file '{audio_file}'\n")
+
+    # Define the output combined audio path
+    combined_audio_path = temp_dir / f"combined_{uuid.uuid4().hex[:8]}.wav"
+
+    try:
+        subprocess.run(
+            [
+                'ffmpeg', '-y', '-f', 'concat', '-safe', '0',
+                '-i', str(list_file),
+                '-c', 'copy',
+                str(combined_audio_path)
+            ],
+            check=True,
+            capture_output=True,
+            text=True
+        )
+        log_messages += "✅ Successfully combined all audio chunks."
+        return str(combined_audio_path), log_messages
+    except subprocess.CalledProcessError as e:
+        logger.error(f"Failed to combine audio files: {e.stderr}")
+        log_messages += f"❌ Failed to combine audio files: {e.stderr}"
+        return None, log_messages
+
+def process_text_and_generate(
+    input_text: str,
+    ref_audio_files: List[str],
+    speed: float,
+    enhance_speech: bool,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+    repetition_penalty: float,
+    language: str
+) -> Tuple[Optional[Tuple[int, np.ndarray]], str]:
     """Process text and generate audio."""
     log_messages = ""
     if not ref_audio_files:
         log_messages += "Please provide at least one reference audio!\n"
         return None, log_messages

-    # clone voices from all file paths (shorten them)
+    # Clone voices from all file paths (shorten them)
     base64_voices = ref_audio_files[:5]

     request = TTSRequest(
@@ -109,7 +195,7 @@ def process_text_and_generate(input_text, ref_audio_files, speed, enhance_speech
                 return None, log_messages
     except Exception as e:
         logger.error(f"Error: {e}")
-        log_messages += f"❌ An Error occured: {e}\n"
+        log_messages += f"❌ An Error occurred: {e}\n"
         return None, log_messages

 def build_gradio_ui():
@@ -187,26 +273,37 @@ def build_gradio_ui():
                     generate_button = gr.Button("Generate Speech")
                 with gr.Column():
                     audio_output = gr.Audio(label="Generated Audio")
-                    log_output = gr.Text(label="Log Output")
+                    log_output = gr.Textbox(label="Log Output", lines=10)

             def process_file_and_generate(
                 file_input, ref_audio_files, speed, enhance_speech,
                 temperature, top_p, top_k, repetition_penalty, language
             ):
                 if not file_input:
-                    return None, "Please provide an input file!"
+                    return None, "❌ Please provide an input file!"

                 try:
                     # Convert input file to text
                     input_text = text_from_file(file_input.name)

-                    return process_text_and_generate(
-                        input_text, ref_audio_files, speed, enhance_speech,
-                        temperature, top_p, top_k, repetition_penalty, language
+                    # Chunk the text
+                    chunks = chunk_text(input_text, max_words=300)
+
+                    # Generate audio from chunks and combine
+                    combined_audio_path, log = generate_audio_from_chunks(
+                        chunks, ref_audio_files, speed, enhance_speech, temperature, top_p,
+                        top_k, repetition_penalty, language
                     )
+
+                    if combined_audio_path:
+                        # Read the combined audio file to return as audio output
+                        waveform, sr = torchaudio.load(combined_audio_path)
+                        return (sr, waveform.numpy()), log
+                    else:
+                        return None, log
                 except Exception as e:
                     logger.error(f"Error processing file: {e}")
-                    return None, f"Error processing file: {str(e)}"
+                    return None, f"❌ Error processing file: {str(e)}"

             generate_button.click(
                 process_file_and_generate,
@@ -229,7 +326,8 @@ def build_gradio_ui():
                     )
                     mic_ref_audio = gr.Audio(
                         label="Record Reference Audio",
-                        sources=["microphone"]
+                        source="microphone",
+                        type="numpy"
                     )

                     with gr.Accordion("Advanced settings", open=False):
@@ -283,16 +381,16 @@ def build_gradio_ui():
                     generate_button_mic = gr.Button("Generate Speech")
                 with gr.Column():
                     audio_output_mic = gr.Audio(label="Generated Audio")
-                    log_output_mic = gr.Text(label="Log Output")
+                    log_output_mic = gr.Textbox(label="Log Output", lines=10)

             def process_mic_and_generate(
                 file_input, mic_ref_audio, speed_mic, enhance_speech_mic,
                 temperature_mic, top_p_mic, top_k_mic, repetition_penalty_mic, language_mic
             ):
-                if not mic_ref_audio:
-                    return None, "Please record an audio!"
+                if mic_ref_audio is None:
+                    return None, "❌ Please record an audio!"
                 if not file_input:
-                    return None, "Please provide an input file!"
+                    return None, "❌ Please provide an input file!"

                 try:
                     # Convert input file to text
@@ -303,21 +401,42 @@ def build_gradio_ui():
                     hash = hashlib.sha1(data).hexdigest()[:10]
                     output_path = temp_dir / (f"mic_{hash}.wav")

-                    torch_audio = torch.from_numpy(mic_ref_audio[1].astype(float))
-                    torchaudio.save(
-                        str(output_path),
-                        torch_audio.unsqueeze(0),
-                        mic_ref_audio[0]
-                    )
+                    # Ensure mic_ref_audio is in the correct format
+                    if isinstance(mic_ref_audio, tuple):
+                        mic_waveform, mic_sr = mic_ref_audio
+                        torch_audio = torch.from_numpy(mic_waveform.astype(float))
+                        torchaudio.save(
+                            str(output_path),
+                            torch_audio.unsqueeze(0),
+                            mic_sr
+                        )
+                    else:
+                        # If mic_ref_audio is not a tuple, handle accordingly
+                        logger.error("Invalid microphone audio format.")
+                        return None, "❌ Invalid microphone audio format."
+
+                    # Clone voice from the saved mic audio
+                    ref_audio_files = [str(output_path)]

-                    return process_text_and_generate(
-                        input_text, [Path(output_path)], speed_mic,
-                        enhance_speech_mic, temperature_mic, top_p_mic,
-                        top_k_mic, repetition_penalty_mic, language_mic
+                    # Chunk the text
+                    chunks = chunk_text(input_text, max_words=300)
+
+                    # Generate audio from chunks and combine
+                    combined_audio_path, log = generate_audio_from_chunks(
+                        chunks, ref_audio_files, speed_mic, enhance_speech_mic,
+                        temperature_mic, top_p_mic, top_k_mic, repetition_penalty_mic,
+                        language_mic
                     )
+
+                    if combined_audio_path:
+                        # Read the combined audio file to return as audio output
+                        waveform, sr = torchaudio.load(combined_audio_path)
+                        return (sr, waveform.numpy()), log
+                    else:
+                        return None, log
                 except Exception as e:
                     logger.error(f"Error processing input: {e}")
-                    return None, f"Error processing input: {str(e)}"
+                    return None, f"❌ Error processing input: {str(e)}"

             generate_button_mic.click(
                 process_mic_and_generate,
@@ -333,4 +452,4 @@ def build_gradio_ui():

 if __name__ == "__main__":
     ui = build_gradio_ui()
-    ui.launch(debug=True, server_name="0.0.0.0", server_port=7860)
+    ui.launch(debug=True, server_name="0.0.0.0", server_port=7860)
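
Note on the chunking behaviour: despite the title, the chunk_text helper added here splits on raw word counts rather than sentence boundaries. A quick standalone check of what it produces (the function body is copied from the diff above; the sample text is made up):

# Standalone check of the chunking added in this diff.
# chunk_text is copied from app.py above; the sample input is invented.
from typing import List

def chunk_text(text: str, max_words: int = 300) -> List[str]:
    """
    Splits the input text into chunks with a maximum of `max_words` per chunk.
    """
    words = text.split()
    chunks = []
    for i in range(0, len(words), max_words):
        chunk = ' '.join(words[i:i + max_words])
        chunks.append(chunk)
    return chunks

sample = "word " * 650                  # 650 whitespace-separated words
chunks = chunk_text(sample, max_words=300)
print(len(chunks))                      # 3 -> chunks of 300, 300 and 50 words
print(len(chunks[-1].split()))          # 50

So a chunk boundary can fall in the middle of a sentence; each chunk is simply the next block of up to 300 words.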
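If true sentence-level chunking is wanted later, a minimal sketch of one way to do it is below. This is not part of this diff: the function name and the naive regex are illustrative assumptions, using only the standard library, and it packs whole sentences into chunks of roughly max_words words so no sentence is cut in half.

# Hypothetical sketch, NOT part of this diff: sentence-aware chunking.
import re
from typing import List

def chunk_text_by_sentence(text: str, max_words: int = 300) -> List[str]:
    # Naive sentence split on '.', '!' or '?' followed by whitespace.
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    chunks, current, current_words = [], [], 0
    for sentence in sentences:
        n_words = len(sentence.split())
        # Close the current chunk if adding this sentence would exceed the budget.
        if current and current_words + n_words > max_words:
            chunks.append(' '.join(current))
            current, current_words = [], 0
        current.append(sentence)
        current_words += n_words
    if current:
        chunks.append(' '.join(current))
    return chunks

A helper like this could drop in where chunk_text is called, since it returns the same List[str] shape that generate_audio_from_chunks consumes.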