Support parallel execution of Silero VAD
Browse files- app.py +28 -13
- cli.py +3 -1
- src/modelCache.py +17 -0
- src/vad.py +60 -28
- src/vadParallel.py +93 -23
- src/whisperContainer.py +12 -26
- tests/vad_test.py +2 -2
app.py
CHANGED
@@ -6,10 +6,9 @@ from io import StringIO
|
|
6 |
import os
|
7 |
import pathlib
|
8 |
import tempfile
|
|
|
9 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
10 |
|
11 |
-
from src.whisperContainer import WhisperContainer, WhisperModelCache
|
12 |
-
|
13 |
# External programs
|
14 |
import ffmpeg
|
15 |
|
@@ -19,6 +18,7 @@ import gradio as gr
|
|
19 |
from src.download import ExceededMaximumDuration, download_url
|
20 |
from src.utils import slugify, write_srt, write_vtt
|
21 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
|
|
22 |
|
23 |
# Limitations (set to -1 to disable)
|
24 |
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
|
@@ -50,11 +50,13 @@ LANGUAGES = [
|
|
50 |
]
|
51 |
|
52 |
class WhisperTranscriber:
|
53 |
-
def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, delete_uploaded_files: bool = DELETE_UPLOADED_FILES):
|
54 |
-
self.model_cache =
|
55 |
self.parallel_device_list = None
|
56 |
-
self.
|
|
|
57 |
self.vad_process_timeout = vad_process_timeout
|
|
|
58 |
|
59 |
self.vad_model = None
|
60 |
self.inputAudioMaxDuration = input_audio_max_duration
|
@@ -142,17 +144,27 @@ class WhisperTranscriber:
|
|
142 |
# No parallel devices, so just run the VAD and Whisper in sequence
|
143 |
return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
# Create parallel context if needed
|
146 |
-
if (self.
|
147 |
# Create a context wih processes and automatically clear the pool after 1 hour of inactivity
|
148 |
-
self.
|
|
|
|
|
|
|
149 |
|
150 |
parallel_vad = ParallelTranscription()
|
151 |
return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
|
152 |
-
config=vadConfig,
|
|
|
153 |
|
154 |
def _has_parallel_devices(self):
|
155 |
-
return self.parallel_device_list is not None and len(self.parallel_device_list) > 0
|
156 |
|
157 |
def _concat_prompt(self, prompt1, prompt2):
|
158 |
if (prompt1 is None):
|
@@ -249,13 +261,15 @@ class WhisperTranscriber:
|
|
249 |
def close(self):
|
250 |
self.clear_cache()
|
251 |
|
252 |
-
if (self.
|
253 |
-
self.
|
|
|
|
|
254 |
|
255 |
|
256 |
def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
|
257 |
-
default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, vad_process_timeout: float = None):
|
258 |
-
ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout)
|
259 |
|
260 |
# Specify a list of devices to use for parallel processing
|
261 |
ui.set_parallel_devices(vad_parallel_devices)
|
@@ -303,6 +317,7 @@ if __name__ == '__main__':
|
|
303 |
parser.add_argument("--default_model_name", type=str, default="medium", help="The default model name.")
|
304 |
parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
|
305 |
parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
|
|
|
306 |
parser.add_argument("--vad_process_timeout", type=float, default="1800", help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
|
307 |
|
308 |
args = parser.parse_args().__dict__
|
|
|
6 |
import os
|
7 |
import pathlib
|
8 |
import tempfile
|
9 |
+
from src.modelCache import ModelCache
|
10 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
11 |
|
|
|
|
|
12 |
# External programs
|
13 |
import ffmpeg
|
14 |
|
|
|
18 |
from src.download import ExceededMaximumDuration, download_url
|
19 |
from src.utils import slugify, write_srt, write_vtt
|
20 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
21 |
+
from src.whisperContainer import WhisperContainer
|
22 |
|
23 |
# Limitations (set to -1 to disable)
|
24 |
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
|
|
|
50 |
]
|
51 |
|
52 |
class WhisperTranscriber:
|
53 |
+
def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES):
|
54 |
+
self.model_cache = ModelCache()
|
55 |
self.parallel_device_list = None
|
56 |
+
self.gpu_parallel_context = None
|
57 |
+
self.cpu_parallel_context = None
|
58 |
self.vad_process_timeout = vad_process_timeout
|
59 |
+
self.vad_cpu_cores = vad_cpu_cores
|
60 |
|
61 |
self.vad_model = None
|
62 |
self.inputAudioMaxDuration = input_audio_max_duration
|
|
|
144 |
# No parallel devices, so just run the VAD and Whisper in sequence
|
145 |
return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
|
146 |
|
147 |
+
gpu_devices = self.parallel_device_list
|
148 |
+
|
149 |
+
if (gpu_devices is None or len(gpu_devices) == 0):
|
150 |
+
# No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
|
151 |
+
gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]
|
152 |
+
|
153 |
# Create parallel context if needed
|
154 |
+
if (self.gpu_parallel_context is None):
|
155 |
# Create a context wih processes and automatically clear the pool after 1 hour of inactivity
|
156 |
+
self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
|
157 |
+
# We also need a CPU context for the VAD
|
158 |
+
if (self.cpu_parallel_context is None):
|
159 |
+
self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)
|
160 |
|
161 |
parallel_vad = ParallelTranscription()
|
162 |
return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
|
163 |
+
config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
|
164 |
+
cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context)
|
165 |
|
166 |
def _has_parallel_devices(self):
|
167 |
+
return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
|
168 |
|
169 |
def _concat_prompt(self, prompt1, prompt2):
|
170 |
if (prompt1 is None):
|
|
|
261 |
def close(self):
|
262 |
self.clear_cache()
|
263 |
|
264 |
+
if (self.gpu_parallel_context is not None):
|
265 |
+
self.gpu_parallel_context.close()
|
266 |
+
if (self.cpu_parallel_context is not None):
|
267 |
+
self.cpu_parallel_context.close()
|
268 |
|
269 |
|
270 |
def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
|
271 |
+
default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, vad_process_timeout: float = None, vad_cpu_cores: int = 1):
|
272 |
+
ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores)
|
273 |
|
274 |
# Specify a list of devices to use for parallel processing
|
275 |
ui.set_parallel_devices(vad_parallel_devices)
|
|
|
317 |
parser.add_argument("--default_model_name", type=str, default="medium", help="The default model name.")
|
318 |
parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
|
319 |
parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
|
320 |
+
parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
|
321 |
parser.add_argument("--vad_process_timeout", type=float, default="1800", help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
|
322 |
|
323 |
args = parser.parse_args().__dict__
|
cli.py
CHANGED
@@ -32,6 +32,7 @@ def cli():
|
|
32 |
parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
|
33 |
parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
|
34 |
parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
|
|
|
35 |
parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
|
36 |
|
37 |
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
@@ -73,8 +74,9 @@ def cli():
|
|
73 |
vad_max_merge_size = args.pop("vad_max_merge_size")
|
74 |
vad_padding = args.pop("vad_padding")
|
75 |
vad_prompt_window = args.pop("vad_prompt_window")
|
|
|
76 |
|
77 |
-
model = WhisperContainer(model_name, device=device, download_root=model_dir)
|
78 |
transcriber = WhisperTranscriber(delete_uploaded_files=False)
|
79 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
80 |
|
|
|
32 |
parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
|
33 |
parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
|
34 |
parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
|
35 |
+
parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
|
36 |
parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
|
37 |
|
38 |
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
|
|
74 |
vad_max_merge_size = args.pop("vad_max_merge_size")
|
75 |
vad_padding = args.pop("vad_padding")
|
76 |
vad_prompt_window = args.pop("vad_prompt_window")
|
77 |
+
vad_cpu_cores = args.pop("vad_cpu_cores")
|
78 |
|
79 |
+
model = WhisperContainer(model_name, device=device, download_root=model_dir, vad_cpu_cores=vad_cpu_cores)
|
80 |
transcriber = WhisperTranscriber(delete_uploaded_files=False)
|
81 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
82 |
|
src/modelCache.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ModelCache:
|
2 |
+
def __init__(self):
|
3 |
+
self._cache = dict()
|
4 |
+
|
5 |
+
def get(self, model_key: str, model_factory):
|
6 |
+
result = self._cache.get(model_key)
|
7 |
+
|
8 |
+
if result is None:
|
9 |
+
result = model_factory()
|
10 |
+
self._cache[model_key] = result
|
11 |
+
return result
|
12 |
+
|
13 |
+
def clear(self):
|
14 |
+
self._cache.clear()
|
15 |
+
|
16 |
+
# A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
|
17 |
+
GLOBAL_MODEL_CACHE = ModelCache()
|
src/vad.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
from abc import ABC, abstractmethod
|
2 |
from collections import Counter, deque
|
|
|
3 |
|
4 |
from typing import Any, Deque, Iterator, List, Dict
|
5 |
|
6 |
from pprint import pprint
|
|
|
7 |
|
8 |
from src.segments import merge_timestamps
|
9 |
from src.whisperContainer import WhisperCallback
|
@@ -76,7 +78,7 @@ class AbstractTranscription(ABC):
|
|
76 |
return load_audio(str, self.sampling_rate, start_time, duration)
|
77 |
|
78 |
@abstractmethod
|
79 |
-
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
|
80 |
"""
|
81 |
Get the start and end timestamps of the sections that should be transcribed by this VAD method.
|
82 |
|
@@ -93,10 +95,10 @@ class AbstractTranscription(ABC):
|
|
93 |
"""
|
94 |
return
|
95 |
|
96 |
-
def get_merged_timestamps(self,
|
97 |
"""
|
98 |
Get the start and end timestamps of the sections that should be transcribed by this VAD method,
|
99 |
-
after merging the segments using the specified configuration.
|
100 |
|
101 |
Parameters
|
102 |
----------
|
@@ -109,21 +111,17 @@ class AbstractTranscription(ABC):
|
|
109 |
-------
|
110 |
A list of start and end timestamps, in fractional seconds.
|
111 |
"""
|
112 |
-
|
113 |
-
|
114 |
-
merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size,
|
115 |
config.segment_padding_left, config.segment_padding_right)
|
116 |
|
117 |
if config.non_speech_strategy != NonSpeechStrategy.SKIP:
|
118 |
-
max_audio_duration = get_audio_duration(audio)
|
119 |
-
|
120 |
# Expand segments to include the gaps between them
|
121 |
if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
|
122 |
# When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
|
123 |
-
merged = self.fill_gaps(merged, total_duration=
|
124 |
elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
|
125 |
# With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
|
126 |
-
merged = self.expand_gaps(merged, total_duration=
|
127 |
else:
|
128 |
raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
|
129 |
|
@@ -147,8 +145,11 @@ class AbstractTranscription(ABC):
|
|
147 |
A list of start and end timestamps, in fractional seconds.
|
148 |
"""
|
149 |
|
|
|
|
|
|
|
150 |
# Get speech timestamps from full audio file
|
151 |
-
merged = self.get_merged_timestamps(
|
152 |
|
153 |
# A deque of transcribed segments that is passed to the next segment as a prompt
|
154 |
prompt_window = deque()
|
@@ -392,22 +393,41 @@ class AbstractTranscription(ABC):
|
|
392 |
|
393 |
|
394 |
class VadSileroTranscription(AbstractTranscription):
|
395 |
-
def __init__(self, sampling_rate: int = 16000):
|
396 |
super().__init__(sampling_rate=sampling_rate)
|
397 |
-
|
398 |
-
self.
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
result = []
|
405 |
|
|
|
|
|
|
|
406 |
# Divide procesisng of audio into chunks
|
407 |
-
chunk_start =
|
408 |
|
409 |
-
while (chunk_start <
|
410 |
-
chunk_duration = min(
|
411 |
|
412 |
print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
|
413 |
wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
|
@@ -421,23 +441,35 @@ class VadSileroTranscription(AbstractTranscription):
|
|
421 |
result.extend(adjusted)
|
422 |
chunk_start += chunk_duration
|
423 |
|
|
|
|
|
|
|
424 |
return result
|
425 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
# A very simple VAD that just marks every N seconds as speech
|
427 |
class VadPeriodicTranscription(AbstractTranscription):
|
428 |
def __init__(self, sampling_rate: int = 16000):
|
429 |
super().__init__(sampling_rate=sampling_rate)
|
430 |
|
431 |
-
def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig):
|
432 |
-
# Get duration in seconds
|
433 |
-
audio_duration = get_audio_duration(audio)
|
434 |
result = []
|
435 |
|
436 |
# Generate a timestamp every N seconds
|
437 |
-
start_timestamp =
|
438 |
|
439 |
-
while (start_timestamp <
|
440 |
-
end_timestamp = min(start_timestamp + config.periodic_duration,
|
441 |
segment_duration = end_timestamp - start_timestamp
|
442 |
|
443 |
# Minimum duration is 1 second
|
|
|
1 |
from abc import ABC, abstractmethod
|
2 |
from collections import Counter, deque
|
3 |
+
import time
|
4 |
|
5 |
from typing import Any, Deque, Iterator, List, Dict
|
6 |
|
7 |
from pprint import pprint
|
8 |
+
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
9 |
|
10 |
from src.segments import merge_timestamps
|
11 |
from src.whisperContainer import WhisperCallback
|
|
|
78 |
return load_audio(str, self.sampling_rate, start_time, duration)
|
79 |
|
80 |
@abstractmethod
|
81 |
+
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
|
82 |
"""
|
83 |
Get the start and end timestamps of the sections that should be transcribed by this VAD method.
|
84 |
|
|
|
95 |
"""
|
96 |
return
|
97 |
|
98 |
+
def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
|
99 |
"""
|
100 |
Get the start and end timestamps of the sections that should be transcribed by this VAD method,
|
101 |
+
after merging the given segments using the specified configuration.
|
102 |
|
103 |
Parameters
|
104 |
----------
|
|
|
111 |
-------
|
112 |
A list of start and end timestamps, in fractional seconds.
|
113 |
"""
|
114 |
+
merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
|
|
|
|
|
115 |
config.segment_padding_left, config.segment_padding_right)
|
116 |
|
117 |
if config.non_speech_strategy != NonSpeechStrategy.SKIP:
|
|
|
|
|
118 |
# Expand segments to include the gaps between them
|
119 |
if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
|
120 |
# When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
|
121 |
+
merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
|
122 |
elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
|
123 |
# With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
|
124 |
+
merged = self.expand_gaps(merged, total_duration=total_duration)
|
125 |
else:
|
126 |
raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
|
127 |
|
|
|
145 |
A list of start and end timestamps, in fractional seconds.
|
146 |
"""
|
147 |
|
148 |
+
max_audio_duration = get_audio_duration(audio)
|
149 |
+
timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
|
150 |
+
|
151 |
# Get speech timestamps from full audio file
|
152 |
+
merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
|
153 |
|
154 |
# A deque of transcribed segments that is passed to the next segment as a prompt
|
155 |
prompt_window = deque()
|
|
|
393 |
|
394 |
|
395 |
class VadSileroTranscription(AbstractTranscription):
|
396 |
+
def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
|
397 |
super().__init__(sampling_rate=sampling_rate)
|
398 |
+
self.model = None
|
399 |
+
self.cache = cache
|
400 |
+
self._initialize_model()
|
401 |
+
|
402 |
+
def _initialize_model(self):
|
403 |
+
if (self.cache is not None):
|
404 |
+
model_key = "VadSileroTranscription"
|
405 |
+
self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
|
406 |
+
print("Loaded Silerio model from cache.")
|
407 |
+
else:
|
408 |
+
self.model, self.get_speech_timestamps = self._create_model()
|
409 |
+
print("Created Silerio model")
|
410 |
+
|
411 |
+
def _create_model(self):
|
412 |
+
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
|
413 |
+
|
414 |
+
# Silero does not benefit from multi-threading
|
415 |
+
torch.set_num_threads(1) # JIT
|
416 |
+
(get_speech_timestamps, _, _, _, _) = utils
|
417 |
+
|
418 |
+
return model, get_speech_timestamps
|
419 |
+
|
420 |
+
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
|
421 |
result = []
|
422 |
|
423 |
+
print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
|
424 |
+
perf_start_time = time.perf_counter()
|
425 |
+
|
426 |
# Divide procesisng of audio into chunks
|
427 |
+
chunk_start = start_time
|
428 |
|
429 |
+
while (chunk_start < end_time):
|
430 |
+
chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)
|
431 |
|
432 |
print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
|
433 |
wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
|
|
|
441 |
result.extend(adjusted)
|
442 |
chunk_start += chunk_duration
|
443 |
|
444 |
+
perf_end_time = time.perf_counter()
|
445 |
+
print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))
|
446 |
+
|
447 |
return result
|
448 |
|
449 |
+
def __getstate__(self):
|
450 |
+
# We only need the sampling rate
|
451 |
+
return { 'sampling_rate': self.sampling_rate }
|
452 |
+
|
453 |
+
def __setstate__(self, state):
|
454 |
+
self.sampling_rate = state['sampling_rate']
|
455 |
+
self.model = None
|
456 |
+
# Use the global cache
|
457 |
+
self.cache = GLOBAL_MODEL_CACHE
|
458 |
+
self._initialize_model()
|
459 |
+
|
460 |
# A very simple VAD that just marks every N seconds as speech
|
461 |
class VadPeriodicTranscription(AbstractTranscription):
|
462 |
def __init__(self, sampling_rate: int = 16000):
|
463 |
super().__init__(sampling_rate=sampling_rate)
|
464 |
|
465 |
+
def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
|
|
|
|
|
466 |
result = []
|
467 |
|
468 |
# Generate a timestamp every N seconds
|
469 |
+
start_timestamp = start_time
|
470 |
|
471 |
+
while (start_timestamp < end_time):
|
472 |
+
end_timestamp = min(start_timestamp + config.periodic_duration, end_time)
|
473 |
segment_duration = end_timestamp - start_timestamp
|
474 |
|
475 |
# Minimum duration is 1 second
|
src/vadParallel.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
import multiprocessing
|
2 |
import threading
|
3 |
import time
|
4 |
-
from src.vad import AbstractTranscription, TranscriptionConfig
|
5 |
from src.whisperContainer import WhisperCallback
|
6 |
|
7 |
from multiprocessing import Pool
|
8 |
|
9 |
-
from typing import List
|
10 |
import os
|
11 |
|
12 |
|
@@ -76,19 +76,28 @@ class ParallelTranscriptionConfig(TranscriptionConfig):
|
|
76 |
super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
|
77 |
self.device_id = device_id
|
78 |
self.override_timestamps = override_timestamps
|
79 |
-
|
80 |
class ParallelTranscription(AbstractTranscription):
|
|
|
|
|
|
|
|
|
81 |
def __init__(self, sampling_rate: int = 16000):
|
82 |
super().__init__(sampling_rate=sampling_rate)
|
83 |
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
# First, get the timestamps for the original audio
|
87 |
-
|
|
|
|
|
|
|
88 |
|
89 |
# Split into a list for each device
|
90 |
# TODO: Split by time instead of by number of chunks
|
91 |
-
merged_split = list(self._split(merged, len(
|
92 |
|
93 |
# Parameters that will be passed to the transcribe function
|
94 |
parameters = []
|
@@ -96,15 +105,15 @@ class ParallelTranscription(AbstractTranscription):
|
|
96 |
|
97 |
for i in range(len(merged_split)):
|
98 |
device_segment_list = list(merged_split[i])
|
99 |
-
device_id =
|
100 |
|
101 |
if (len(device_segment_list) <= 0):
|
102 |
continue
|
103 |
|
104 |
-
print("Device " + device_id + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
|
105 |
|
106 |
# Create a new config with the given device ID
|
107 |
-
device_config = ParallelTranscriptionConfig(
|
108 |
segment_index += len(device_segment_list)
|
109 |
|
110 |
parameters.append([audio, whisperCallable, device_config]);
|
@@ -119,12 +128,12 @@ class ParallelTranscription(AbstractTranscription):
|
|
119 |
|
120 |
# Spawn a separate process for each device
|
121 |
try:
|
122 |
-
if (
|
123 |
-
|
124 |
created_context = True
|
125 |
|
126 |
# Get a pool of processes
|
127 |
-
pool =
|
128 |
|
129 |
# Run the transcription in parallel
|
130 |
results = pool.starmap(self.transcribe, parameters)
|
@@ -140,29 +149,90 @@ class ParallelTranscription(AbstractTranscription):
|
|
140 |
|
141 |
finally:
|
142 |
# Return the pool to the context
|
143 |
-
if (
|
144 |
-
|
145 |
# Always close the context if we created it
|
146 |
if (created_context):
|
147 |
-
|
148 |
|
149 |
return merged
|
150 |
|
151 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
return []
|
153 |
|
154 |
-
def get_merged_timestamps(self,
|
155 |
# Override timestamps that will be processed
|
156 |
if (config.override_timestamps is not None):
|
157 |
print("Using override timestamps of size " + str(len(config.override_timestamps)))
|
158 |
return config.override_timestamps
|
159 |
-
return super().get_merged_timestamps(
|
160 |
|
161 |
def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
|
162 |
-
# Override device ID
|
163 |
-
if (
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
return super().transcribe(audio, whisperCallable, config)
|
167 |
|
168 |
def _split(self, a, n):
|
|
|
1 |
import multiprocessing
|
2 |
import threading
|
3 |
import time
|
4 |
+
from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
|
5 |
from src.whisperContainer import WhisperCallback
|
6 |
|
7 |
from multiprocessing import Pool
|
8 |
|
9 |
+
from typing import Any, Dict, List
|
10 |
import os
|
11 |
|
12 |
|
|
|
76 |
super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
|
77 |
self.device_id = device_id
|
78 |
self.override_timestamps = override_timestamps
|
79 |
+
|
80 |
class ParallelTranscription(AbstractTranscription):
|
81 |
+
# Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
|
82 |
+
# into smaller segments than 2 minute (min 6 seconds per CPU core)
|
83 |
+
MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60
|
84 |
+
|
85 |
def __init__(self, sampling_rate: int = 16000):
|
86 |
super().__init__(sampling_rate=sampling_rate)
|
87 |
|
88 |
+
def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
|
89 |
+
cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None):
|
90 |
+
total_duration = get_audio_duration(audio)
|
91 |
+
|
92 |
# First, get the timestamps for the original audio
|
93 |
+
if (cpu_device_count > 1):
|
94 |
+
merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
|
95 |
+
else:
|
96 |
+
merged = transcription.get_merged_timestamps(audio, config, total_duration)
|
97 |
|
98 |
# Split into a list for each device
|
99 |
# TODO: Split by time instead of by number of chunks
|
100 |
+
merged_split = list(self._split(merged, len(gpu_devices)))
|
101 |
|
102 |
# Parameters that will be passed to the transcribe function
|
103 |
parameters = []
|
|
|
105 |
|
106 |
for i in range(len(merged_split)):
|
107 |
device_segment_list = list(merged_split[i])
|
108 |
+
device_id = gpu_devices[i]
|
109 |
|
110 |
if (len(device_segment_list) <= 0):
|
111 |
continue
|
112 |
|
113 |
+
print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
|
114 |
|
115 |
# Create a new config with the given device ID
|
116 |
+
device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
|
117 |
segment_index += len(device_segment_list)
|
118 |
|
119 |
parameters.append([audio, whisperCallable, device_config]);
|
|
|
128 |
|
129 |
# Spawn a separate process for each device
|
130 |
try:
|
131 |
+
if (gpu_parallel_context is None):
|
132 |
+
gpu_parallel_context = ParallelContext(len(gpu_devices))
|
133 |
created_context = True
|
134 |
|
135 |
# Get a pool of processes
|
136 |
+
pool = gpu_parallel_context.get_pool()
|
137 |
|
138 |
# Run the transcription in parallel
|
139 |
results = pool.starmap(self.transcribe, parameters)
|
|
|
149 |
|
150 |
finally:
|
151 |
# Return the pool to the context
|
152 |
+
if (gpu_parallel_context is not None):
|
153 |
+
gpu_parallel_context.return_pool(pool)
|
154 |
# Always close the context if we created it
|
155 |
if (created_context):
|
156 |
+
gpu_parallel_context.close()
|
157 |
|
158 |
return merged
|
159 |
|
160 |
+
def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
|
161 |
+
cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
|
162 |
+
parameters = []
|
163 |
+
|
164 |
+
chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
|
165 |
+
chunk_start = 0
|
166 |
+
cpu_device_id = 0
|
167 |
+
|
168 |
+
perf_start_time = time.perf_counter()
|
169 |
+
|
170 |
+
# Create chunks that will be processed on the CPU
|
171 |
+
while (chunk_start < total_duration):
|
172 |
+
chunk_end = min(chunk_start + chunk_size, total_duration)
|
173 |
+
|
174 |
+
print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
|
175 |
+
str(chunk_end) + " on CPU device " + str(cpu_device_id))
|
176 |
+
parameters.append([audio, config, chunk_start, chunk_end]);
|
177 |
+
|
178 |
+
cpu_device_id += 1
|
179 |
+
chunk_start = chunk_end
|
180 |
+
|
181 |
+
created_context = False
|
182 |
+
|
183 |
+
# Spawn a separate process for each device
|
184 |
+
try:
|
185 |
+
if (cpu_parallel_context is None):
|
186 |
+
cpu_parallel_context = ParallelContext(cpu_device_count)
|
187 |
+
created_context = True
|
188 |
+
|
189 |
+
# Get a pool of processes
|
190 |
+
pool = cpu_parallel_context.get_pool()
|
191 |
+
|
192 |
+
# Run the transcription in parallel. Note that transcription must be picklable.
|
193 |
+
results = pool.starmap(transcription.get_transcribe_timestamps, parameters)
|
194 |
+
|
195 |
+
timestamps = []
|
196 |
+
|
197 |
+
# Flatten the results
|
198 |
+
for result in results:
|
199 |
+
timestamps.extend(result)
|
200 |
+
|
201 |
+
merged = transcription.get_merged_timestamps(timestamps, config, total_duration)
|
202 |
+
|
203 |
+
perf_end_time = time.perf_counter()
|
204 |
+
print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
|
205 |
+
return merged
|
206 |
+
|
207 |
+
finally:
|
208 |
+
# Return the pool to the context
|
209 |
+
if (cpu_parallel_context is not None):
|
210 |
+
cpu_parallel_context.return_pool(pool)
|
211 |
+
# Always close the context if we created it
|
212 |
+
if (created_context):
|
213 |
+
cpu_parallel_context.close()
|
214 |
+
|
215 |
+
def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
|
216 |
return []
|
217 |
|
218 |
+
def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
|
219 |
# Override timestamps that will be processed
|
220 |
if (config.override_timestamps is not None):
|
221 |
print("Using override timestamps of size " + str(len(config.override_timestamps)))
|
222 |
return config.override_timestamps
|
223 |
+
return super().get_merged_timestamps(timestamps, config, total_duration)
|
224 |
|
225 |
def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
|
226 |
+
# Override device ID the first time
|
227 |
+
if (os.environ.get("INITIALIZED", None) is None):
|
228 |
+
os.environ["INITIALIZED"] = "1"
|
229 |
+
|
230 |
+
# Note that this may be None if the user didn't specify a device. In that case, Whisper will
|
231 |
+
# just use the default GPU device.
|
232 |
+
if (config.device_id is not None):
|
233 |
+
print("Using device " + config.device_id)
|
234 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
|
235 |
+
|
236 |
return super().transcribe(audio, whisperCallable, config)
|
237 |
|
238 |
def _split(self, a, n):
|
src/whisperContainer.py
CHANGED
@@ -1,29 +1,10 @@
|
|
1 |
# External programs
|
2 |
import whisper
|
3 |
|
4 |
-
|
5 |
-
def __init__(self):
|
6 |
-
self._cache = dict()
|
7 |
-
|
8 |
-
def get(self, model_name, device: str = None):
|
9 |
-
key = model_name + ":" + (device if device else '')
|
10 |
-
|
11 |
-
result = self._cache.get(key)
|
12 |
-
|
13 |
-
if result is None:
|
14 |
-
print("Loading whisper model " + model_name)
|
15 |
-
result = whisper.load_model(name=model_name, device=device)
|
16 |
-
self._cache[key] = result
|
17 |
-
return result
|
18 |
-
|
19 |
-
def clear(self):
|
20 |
-
self._cache.clear()
|
21 |
-
|
22 |
-
# A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
|
23 |
-
GLOBAL_WHISPER_MODEL_CACHE = WhisperModelCache()
|
24 |
|
25 |
class WhisperContainer:
|
26 |
-
def __init__(self, model_name: str, device: str = None, download_root: str = None, cache:
|
27 |
self.model_name = model_name
|
28 |
self.device = device
|
29 |
self.download_root = download_root
|
@@ -36,12 +17,16 @@ class WhisperContainer:
|
|
36 |
if self.model is None:
|
37 |
|
38 |
if (self.cache is None):
|
39 |
-
|
40 |
-
self.model = whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
|
41 |
else:
|
42 |
-
|
|
|
43 |
return self.model
|
44 |
|
|
|
|
|
|
|
|
|
45 |
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
|
46 |
"""
|
47 |
Create a WhisperCallback object that can be used to transcript audio files.
|
@@ -65,14 +50,15 @@ class WhisperContainer:
|
|
65 |
|
66 |
# This is required for multiprocessing
|
67 |
def __getstate__(self):
|
68 |
-
return { "model_name": self.model_name, "device": self.device }
|
69 |
|
70 |
def __setstate__(self, state):
|
71 |
self.model_name = state["model_name"]
|
72 |
self.device = state["device"]
|
|
|
73 |
self.model = None
|
74 |
# Depickled objects must use the global cache
|
75 |
-
self.cache =
|
76 |
|
77 |
|
78 |
class WhisperCallback:
|
|
|
1 |
# External programs
|
2 |
import whisper
|
3 |
|
4 |
+
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
class WhisperContainer:
|
7 |
+
def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
|
8 |
self.model_name = model_name
|
9 |
self.device = device
|
10 |
self.download_root = download_root
|
|
|
17 |
if self.model is None:
|
18 |
|
19 |
if (self.cache is None):
|
20 |
+
self.model = self._create_model()
|
|
|
21 |
else:
|
22 |
+
model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
|
23 |
+
self.model = self.cache.get(model_key, self._create_model)
|
24 |
return self.model
|
25 |
|
26 |
+
def _create_model(self):
|
27 |
+
print("Loading whisper model " + self.model_name)
|
28 |
+
return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
|
29 |
+
|
30 |
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
|
31 |
"""
|
32 |
Create a WhisperCallback object that can be used to transcript audio files.
|
|
|
50 |
|
51 |
# This is required for multiprocessing
|
52 |
def __getstate__(self):
|
53 |
+
return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
|
54 |
|
55 |
def __setstate__(self, state):
|
56 |
self.model_name = state["model_name"]
|
57 |
self.device = state["device"]
|
58 |
+
self.download_root = state["download_root"]
|
59 |
self.model = None
|
60 |
# Depickled objects must use the global cache
|
61 |
+
self.cache = GLOBAL_MODEL_CACHE
|
62 |
|
63 |
|
64 |
class WhisperCallback:
|
tests/vad_test.py
CHANGED
@@ -5,7 +5,7 @@ import sys
|
|
5 |
|
6 |
sys.path.append('../whisper-webui')
|
7 |
|
8 |
-
from src.vad import AbstractTranscription, VadSileroTranscription
|
9 |
|
10 |
class TestVad(unittest.TestCase):
|
11 |
def __init__(self, *args, **kwargs):
|
@@ -55,7 +55,7 @@ class MockVadTranscription(AbstractTranscription):
|
|
55 |
# For mocking, this just returns a simple numppy array
|
56 |
return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
|
57 |
|
58 |
-
def get_transcribe_timestamps(self, audio: str):
|
59 |
result = []
|
60 |
|
61 |
result.append( { 'start': 30, 'end': 60 } )
|
|
|
5 |
|
6 |
sys.path.append('../whisper-webui')
|
7 |
|
8 |
+
from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
|
9 |
|
10 |
class TestVad(unittest.TestCase):
|
11 |
def __init__(self, *args, **kwargs):
|
|
|
55 |
# For mocking, this just returns a simple numppy array
|
56 |
return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
|
57 |
|
58 |
+
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
|
59 |
result = []
|
60 |
|
61 |
result.append( { 'start': 30, 'end': 60 } )
|