aadnk commited on
Commit
c90f138
1 Parent(s): c963436

Ensure progress bar works for multiple files

Browse files
Files changed (2) hide show
  1. app.py +23 -6
  2. src/source.py +22 -12
app.py CHANGED
@@ -12,7 +12,7 @@ import numpy as np
12
 
13
  import torch
14
  from src.config import ApplicationConfig
15
- from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
16
  from src.modelCache import ModelCache
17
  from src.source import get_audio_source_collection
18
  from src.vadParallel import ParallelContext, ParallelTranscription
@@ -135,9 +135,17 @@ class WhisperTranscriber:
135
 
136
  outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
137
 
 
 
 
 
 
 
 
138
  # Execute whisper
139
  for source in sources:
140
  source_prefix = ""
 
141
 
142
  if (len(sources) > 1):
143
  # Prefix (minimum 2 digits)
@@ -145,10 +153,18 @@ class WhisperTranscriber:
145
  source_prefix = str(source_index).zfill(2) + "_"
146
  print("Transcribing ", source.source_path)
147
 
 
 
 
 
 
148
  # Transcribe
149
- result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, progress, **decodeOptions)
150
  filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
151
 
 
 
 
152
  source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
153
 
154
  if len(sources) > 1:
@@ -209,19 +225,20 @@ class WhisperTranscriber:
209
 
210
  def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
211
  vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
212
- progress: gr.Progress = None, **decodeOptions: dict):
213
 
214
  initial_prompt = decodeOptions.pop('initial_prompt', None)
215
 
 
 
 
 
216
  if ('task' in decodeOptions):
217
  task = decodeOptions.pop('task')
218
 
219
  # Callable for processing an audio file
220
  whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
221
 
222
- # A listener that will report progress to Gradio
223
- progressListener = self._create_progress_listener(progress)
224
-
225
  # The results
226
  if (vad == 'silero-vad'):
227
  # Silero VAD where non-speech gaps are transcribed
 
12
 
13
  import torch
14
  from src.config import ApplicationConfig
15
+ from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
16
  from src.modelCache import ModelCache
17
  from src.source import get_audio_source_collection
18
  from src.vadParallel import ParallelContext, ParallelTranscription
 
135
 
136
  outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
137
 
138
+ # Progress
139
+ total_duration = sum([source.get_audio_duration() for source in sources])
140
+ current_progress = 0
141
+
142
+ # A listener that will report progress to Gradio
143
+ root_progress_listener = self._create_progress_listener(progress)
144
+
145
  # Execute whisper
146
  for source in sources:
147
  source_prefix = ""
148
+ source_audio_duration = source.get_audio_duration()
149
 
150
  if (len(sources) > 1):
151
  # Prefix (minimum 2 digits)
 
153
  source_prefix = str(source_index).zfill(2) + "_"
154
  print("Transcribing ", source.source_path)
155
 
156
+ scaled_progress_listener = SubTaskProgressListener(root_progress_listener,
157
+ base_task_total=total_duration,
158
+ sub_task_start=current_progress,
159
+ sub_task_total=source_audio_duration)
160
+
161
  # Transcribe
162
+ result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, scaled_progress_listener, **decodeOptions)
163
  filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
164
 
165
+ # Update progress
166
+ current_progress += source_audio_duration
167
+
168
  source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
169
 
170
  if len(sources) > 1:
 
225
 
226
  def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
227
  vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
228
+ progressListener: ProgressListener = None, **decodeOptions: dict):
229
 
230
  initial_prompt = decodeOptions.pop('initial_prompt', None)
231
 
232
+ if progressListener is None:
233
+ # Default progress listener
234
+ progressListener = ProgressListener()
235
+
236
  if ('task' in decodeOptions):
237
  task = decodeOptions.pop('task')
238
 
239
  # Callable for processing an audio file
240
  whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
241
 
 
 
 
242
  # The results
243
  if (vad == 'silero-vad'):
244
  # Silero VAD where non-speech gaps are transcribed
src/source.py CHANGED
@@ -12,15 +12,22 @@ from src.download import ExceededMaximumDuration, download_url
12
  MAX_FILE_PREFIX_LENGTH = 17
13
 
14
  class AudioSource:
15
- def __init__(self, source_path, source_name = None):
16
  self.source_path = source_path
17
  self.source_name = source_name
 
18
 
19
  # Load source name if not provided
20
  if (self.source_name is None):
21
  file_path = pathlib.Path(self.source_path)
22
  self.source_name = file_path.name
23
 
 
 
 
 
 
 
24
  def get_full_name(self):
25
  return self.source_name
26
 
@@ -53,18 +60,21 @@ def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneDat
53
  if (microphoneData is not None):
54
  output.append(AudioSource(microphoneData))
55
 
56
- total_duration = 0
 
 
 
 
 
 
 
 
 
57
 
58
- # Calculate total audio length. We do this even if input_audio_max_duration
59
- # is disabled to ensure that all the audio files are valid.
60
- for source in output:
61
- audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
62
- total_duration += float(audioDuration)
63
 
64
- # Ensure the total duration of the audio is not too long
65
- if input_audio_max_duration > 0:
66
- if float(total_duration) > input_audio_max_duration:
67
- raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
68
-
69
  # Return a list of audio sources
70
  return output
 
12
  MAX_FILE_PREFIX_LENGTH = 17
13
 
14
  class AudioSource:
15
+ def __init__(self, source_path, source_name = None, audio_duration = None):
16
  self.source_path = source_path
17
  self.source_name = source_name
18
+ self._audio_duration = audio_duration
19
 
20
  # Load source name if not provided
21
  if (self.source_name is None):
22
  file_path = pathlib.Path(self.source_path)
23
  self.source_name = file_path.name
24
 
25
+ def get_audio_duration(self):
26
+ if self._audio_duration is None:
27
+ self._audio_duration = float(ffmpeg.probe(self.source_path)["format"]["duration"])
28
+
29
+ return self._audio_duration
30
+
31
  def get_full_name(self):
32
  return self.source_name
33
 
 
60
  if (microphoneData is not None):
61
  output.append(AudioSource(microphoneData))
62
 
63
+ total_duration = 0
64
+
65
+ # Calculate total audio length. We do this even if input_audio_max_duration
66
+ # is disabled to ensure that all the audio files are valid.
67
+ for source in output:
68
+ audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
69
+ total_duration += float(audioDuration)
70
+
71
+ # Save audio duration
72
+ source._audio_duration = float(audioDuration)
73
 
74
+ # Ensure the total duration of the audio is not too long
75
+ if input_audio_max_duration > 0:
76
+ if float(total_duration) > input_audio_max_duration:
77
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
 
78
 
 
 
 
 
 
79
  # Return a list of audio sources
80
  return output