Cleanup code
Browse files- app.py +13 -13
- src/segments.py +9 -1
- src/vad.py +42 -49
app.py
CHANGED
@@ -14,7 +14,7 @@ import gradio as gr
|
|
14 |
|
15 |
from src.download import ExceededMaximumDuration, download_url
|
16 |
from src.utils import slugify, write_srt, write_vtt
|
17 |
-
from src.vad import NonSpeechStrategy, VadPeriodicTranscription, VadSileroTranscription
|
18 |
|
19 |
# Limitations (set to -1 to disable)
|
20 |
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
|
@@ -96,38 +96,38 @@ class WhisperTranscriber:
|
|
96 |
# The results
|
97 |
if (vad == 'silero-vad'):
|
98 |
# Silero VAD where non-speech gaps are transcribed
|
99 |
-
process_gaps = self.
|
100 |
-
result =
|
101 |
elif (vad == 'silero-vad-skip-gaps'):
|
102 |
# Silero VAD where non-speech gaps are simply ignored
|
103 |
-
skip_gaps = self.
|
104 |
-
result = skip_gaps.transcribe(audio_path, whisperCallable)
|
105 |
elif (vad == 'silero-vad-expand-into-gaps'):
|
106 |
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
107 |
-
expand_gaps = self.
|
108 |
-
result = expand_gaps.transcribe(audio_path, whisperCallable)
|
109 |
elif (vad == 'periodic-vad'):
|
110 |
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
111 |
# it may create a break in the middle of a sentence, causing some artifacts.
|
112 |
-
periodic_vad = VadPeriodicTranscription(
|
113 |
-
result = periodic_vad.transcribe(audio_path, whisperCallable)
|
114 |
else:
|
115 |
# Default VAD
|
116 |
result = whisperCallable(audio_path, None, None)
|
117 |
|
118 |
return result
|
119 |
|
120 |
-
def
|
121 |
# Use Silero VAD
|
122 |
if (self.vad_model is None):
|
123 |
self.vad_model = VadSileroTranscription()
|
124 |
|
125 |
-
|
126 |
max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
|
127 |
segment_padding_left=vadPadding, segment_padding_right=vadPadding,
|
128 |
-
max_prompt_window=vadPromptWindow
|
129 |
|
130 |
-
return
|
131 |
|
132 |
def write_result(self, result: dict, source_name: str, output_dir: str):
|
133 |
if not os.path.exists(output_dir):
|
|
|
14 |
|
15 |
from src.download import ExceededMaximumDuration, download_url
|
16 |
from src.utils import slugify, write_srt, write_vtt
|
17 |
+
from src.vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
18 |
|
19 |
# Limitations (set to -1 to disable)
|
20 |
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
|
|
|
96 |
# The results
|
97 |
if (vad == 'silero-vad'):
|
98 |
# Silero VAD where non-speech gaps are transcribed
|
99 |
+
process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
100 |
+
result = self.vad_model.transcribe(audio_path, whisperCallable, process_gaps)
|
101 |
elif (vad == 'silero-vad-skip-gaps'):
|
102 |
# Silero VAD where non-speech gaps are simply ignored
|
103 |
+
skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
104 |
+
result = skip_gaps.transcribe(audio_path, whisperCallable, skip_gaps)
|
105 |
elif (vad == 'silero-vad-expand-into-gaps'):
|
106 |
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
107 |
+
expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
|
108 |
+
result = expand_gaps.transcribe(audio_path, whisperCallable, expand_gaps)
|
109 |
elif (vad == 'periodic-vad'):
|
110 |
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
111 |
# it may create a break in the middle of a sentence, causing some artifacts.
|
112 |
+
periodic_vad = VadPeriodicTranscription()
|
113 |
+
result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
|
114 |
else:
|
115 |
# Default VAD
|
116 |
result = whisperCallable(audio_path, None, None)
|
117 |
|
118 |
return result
|
119 |
|
120 |
+
def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
|
121 |
# Use Silero VAD
|
122 |
if (self.vad_model is None):
|
123 |
self.vad_model = VadSileroTranscription()
|
124 |
|
125 |
+
config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
|
126 |
max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
|
127 |
segment_padding_left=vadPadding, segment_padding_right=vadPadding,
|
128 |
+
max_prompt_window=vadPromptWindow)
|
129 |
|
130 |
+
return config
|
131 |
|
132 |
def write_result(self, result: dict, source_name: str, output_dir: str):
|
133 |
if not os.path.exists(output_dir):
|
src/segments.py
CHANGED
@@ -7,6 +7,13 @@ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5,
|
|
7 |
|
8 |
if len(timestamps) == 0:
|
9 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
processed_time = 0
|
12 |
current_segment = None
|
@@ -17,7 +24,8 @@ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5,
|
|
17 |
delta = next_segment['start'] - processed_time
|
18 |
|
19 |
# Note that segments can still be longer than the max merge size, they just won't be merged in that case
|
20 |
-
if current_segment is None or
|
|
|
21 |
# Finish the current segment
|
22 |
if current_segment is not None:
|
23 |
# Add right padding
|
|
|
7 |
|
8 |
if len(timestamps) == 0:
|
9 |
return result
|
10 |
+
if max_merge_size is None:
|
11 |
+
return timestamps
|
12 |
+
|
13 |
+
if padding_left is None:
|
14 |
+
padding_left = 0
|
15 |
+
if padding_right is None:
|
16 |
+
padding_right = 0
|
17 |
|
18 |
processed_time = 0
|
19 |
current_segment = None
|
|
|
24 |
delta = next_segment['start'] - processed_time
|
25 |
|
26 |
# Note that segments can still be longer than the max merge size, they just won't be merged in that case
|
27 |
+
if current_segment is None or (merge_window is not None and delta > merge_window) \
|
28 |
+
or next_segment['end'] - current_segment['start'] > max_merge_size:
|
29 |
# Finish the current segment
|
30 |
if current_segment is not None:
|
31 |
# Add right padding
|
src/vad.py
CHANGED
@@ -38,45 +38,43 @@ class NonSpeechStrategy(Enum):
|
|
38 |
|
39 |
# Defaults for Silero
|
40 |
SPEECH_TRESHOLD = 0.3
|
41 |
-
MAX_SILENT_PERIOD = 10 # seconds
|
42 |
-
MAX_MERGE_SIZE = 150 # Do not create segments larger than 2.5 minutes
|
43 |
-
|
44 |
-
# Default segment padding
|
45 |
-
SEGMENT_PADDING_LEFT = 1 # Start detected text segment early
|
46 |
-
SEGMENT_PADDING_RIGHT = 1 # End detected segments late
|
47 |
|
48 |
# Minimum size of segments to process
|
49 |
MIN_SEGMENT_DURATION = 1
|
50 |
|
51 |
-
# Always merge segments that are less than this duration apart
|
52 |
-
MIN_FORCE_MERGE_GAP = 0.5
|
53 |
-
FORCE_MERGE_SEGMENT_MULTIPLIER = 1.5
|
54 |
-
|
55 |
# The maximum time for texts from old segments to be used in the next segment
|
56 |
MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
|
57 |
PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
|
58 |
|
59 |
VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
|
60 |
|
61 |
-
class
|
62 |
-
def __init__(self,
|
63 |
-
|
64 |
-
|
|
|
65 |
self.segment_padding_left = segment_padding_left
|
66 |
self.segment_padding_right = segment_padding_right
|
67 |
self.max_silent_period = max_silent_period
|
68 |
self.max_merge_size = max_merge_size
|
69 |
-
self.non_speech_strategy = non_speech_strategy
|
70 |
self.max_prompt_window = max_prompt_window
|
71 |
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
76 |
return load_audio(str, self.sampling_rate, start_time, duration)
|
77 |
|
78 |
@abstractmethod
|
79 |
-
def get_transcribe_timestamps(self, audio: str):
|
80 |
"""
|
81 |
Get the start and end timestamps of the sections that should be transcribed by this VAD method.
|
82 |
|
@@ -84,6 +82,8 @@ class AbstractTranscription(ABC):
|
|
84 |
----------
|
85 |
audio: str
|
86 |
The audio file.
|
|
|
|
|
87 |
|
88 |
Returns
|
89 |
-------
|
@@ -91,7 +91,7 @@ class AbstractTranscription(ABC):
|
|
91 |
"""
|
92 |
return
|
93 |
|
94 |
-
def transcribe(self, audio: str, whisperCallable):
|
95 |
"""
|
96 |
Transcribe the given audo file.
|
97 |
|
@@ -110,12 +110,12 @@ class AbstractTranscription(ABC):
|
|
110 |
"""
|
111 |
|
112 |
# get speech timestamps from full audio file
|
113 |
-
seconds_timestamps = self.get_transcribe_timestamps(audio)
|
114 |
|
115 |
#for seconds_timestamp in seconds_timestamps:
|
116 |
# print("VAD timestamp ", format_timestamp(seconds_timestamp['start']), " to ", format_timestamp(seconds_timestamp['end']))
|
117 |
|
118 |
-
merged = merge_timestamps(seconds_timestamps,
|
119 |
|
120 |
# A deque of transcribed segments that is passed to the next segment as a prompt
|
121 |
prompt_window = deque()
|
@@ -123,18 +123,18 @@ class AbstractTranscription(ABC):
|
|
123 |
print("Timestamps:")
|
124 |
pprint(merged)
|
125 |
|
126 |
-
if
|
127 |
max_audio_duration = get_audio_duration(audio)
|
128 |
|
129 |
# Expand segments to include the gaps between them
|
130 |
-
if (
|
131 |
# When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
|
132 |
-
merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=
|
133 |
-
elif
|
134 |
# With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
|
135 |
merged = self.expand_gaps(merged, total_duration=max_audio_duration)
|
136 |
else:
|
137 |
-
raise Exception("Unknown non-speech strategy: " + str(
|
138 |
|
139 |
print("Transcribing non-speech:")
|
140 |
pprint(merged)
|
@@ -193,15 +193,15 @@ class AbstractTranscription(ABC):
|
|
193 |
languageCounter[segment_result['language']] += 1
|
194 |
|
195 |
# Update prompt window
|
196 |
-
self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap)
|
197 |
|
198 |
if detected_language is not None:
|
199 |
result['language'] = detected_language
|
200 |
|
201 |
return result
|
202 |
|
203 |
-
def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool
|
204 |
-
if (
|
205 |
# Add segments to the current prompt window (unless it is a speech gap)
|
206 |
if not segment_gap:
|
207 |
for segment in adjusted_segments:
|
@@ -213,7 +213,7 @@ class AbstractTranscription(ABC):
|
|
213 |
# Time expanded in the segments should be discounted from the prompt window
|
214 |
first_expand_time = prompt_window[0].get('expand_amount', 0)
|
215 |
|
216 |
-
if (first_end_time - first_expand_time < segment_end -
|
217 |
prompt_window.popleft()
|
218 |
else:
|
219 |
break
|
@@ -371,20 +371,14 @@ class AbstractTranscription(ABC):
|
|
371 |
return result
|
372 |
|
373 |
class VadSileroTranscription(AbstractTranscription):
|
374 |
-
def __init__(self,
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
self.get_speech_timestamps = copy.get_speech_timestamps
|
383 |
-
else:
|
384 |
-
self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
|
385 |
-
(self.get_speech_timestamps, _, _, _, _) = utils
|
386 |
-
|
387 |
-
def get_transcribe_timestamps(self, audio: str):
|
388 |
audio_duration = get_audio_duration(audio)
|
389 |
result = []
|
390 |
|
@@ -410,11 +404,10 @@ class VadSileroTranscription(AbstractTranscription):
|
|
410 |
|
411 |
# A very simple VAD that just marks every N seconds as speech
|
412 |
class VadPeriodicTranscription(AbstractTranscription):
|
413 |
-
def __init__(self,
|
414 |
-
super().__init__()
|
415 |
-
self.periodic_duration = periodic_duration
|
416 |
|
417 |
-
def get_transcribe_timestamps(self, audio: str):
|
418 |
# Get duration in seconds
|
419 |
audio_duration = get_audio_duration(audio)
|
420 |
result = []
|
@@ -423,7 +416,7 @@ class VadPeriodicTranscription(AbstractTranscription):
|
|
423 |
start_timestamp = 0
|
424 |
|
425 |
while (start_timestamp < audio_duration):
|
426 |
-
end_timestamp = min(start_timestamp +
|
427 |
segment_duration = end_timestamp - start_timestamp
|
428 |
|
429 |
# Minimum duration is 1 second
|
|
|
38 |
|
39 |
# Defaults for Silero
|
40 |
SPEECH_TRESHOLD = 0.3
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
# Minimum size of segments to process
|
43 |
MIN_SEGMENT_DURATION = 1
|
44 |
|
|
|
|
|
|
|
|
|
45 |
# The maximum time for texts from old segments to be used in the next segment
|
46 |
MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
|
47 |
PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
|
48 |
|
49 |
VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
|
50 |
|
51 |
+
class TranscriptionConfig(ABC):
|
52 |
+
def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
|
53 |
+
segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
|
54 |
+
max_merge_size: float = None, max_prompt_window: float = None):
|
55 |
+
self.non_speech_strategy = non_speech_strategy
|
56 |
self.segment_padding_left = segment_padding_left
|
57 |
self.segment_padding_right = segment_padding_right
|
58 |
self.max_silent_period = max_silent_period
|
59 |
self.max_merge_size = max_merge_size
|
|
|
60 |
self.max_prompt_window = max_prompt_window
|
61 |
|
62 |
+
class PeriodicTranscriptionConfig(TranscriptionConfig):
|
63 |
+
def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
|
64 |
+
segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
|
65 |
+
max_merge_size: float = None, max_prompt_window: float = None):
|
66 |
+
super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window)
|
67 |
+
self.periodic_duration = periodic_duration
|
68 |
+
|
69 |
+
class AbstractTranscription(ABC):
|
70 |
+
def __init__(self, sampling_rate: int = 16000):
|
71 |
+
self.sampling_rate = sampling_rate
|
72 |
|
73 |
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
74 |
return load_audio(str, self.sampling_rate, start_time, duration)
|
75 |
|
76 |
@abstractmethod
|
77 |
+
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
|
78 |
"""
|
79 |
Get the start and end timestamps of the sections that should be transcribed by this VAD method.
|
80 |
|
|
|
82 |
----------
|
83 |
audio: str
|
84 |
The audio file.
|
85 |
+
config: TranscriptionConfig
|
86 |
+
The transcription configuration.
|
87 |
|
88 |
Returns
|
89 |
-------
|
|
|
91 |
"""
|
92 |
return
|
93 |
|
94 |
+
def transcribe(self, audio: str, whisperCallable, config: TranscriptionConfig):
|
95 |
"""
|
96 |
Transcribe the given audo file.
|
97 |
|
|
|
110 |
"""
|
111 |
|
112 |
# get speech timestamps from full audio file
|
113 |
+
seconds_timestamps = self.get_transcribe_timestamps(audio, config)
|
114 |
|
115 |
#for seconds_timestamp in seconds_timestamps:
|
116 |
# print("VAD timestamp ", format_timestamp(seconds_timestamp['start']), " to ", format_timestamp(seconds_timestamp['end']))
|
117 |
|
118 |
+
merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size, config.segment_padding_left, config.segment_padding_right)
|
119 |
|
120 |
# A deque of transcribed segments that is passed to the next segment as a prompt
|
121 |
prompt_window = deque()
|
|
|
123 |
print("Timestamps:")
|
124 |
pprint(merged)
|
125 |
|
126 |
+
if config.non_speech_strategy != NonSpeechStrategy.SKIP:
|
127 |
max_audio_duration = get_audio_duration(audio)
|
128 |
|
129 |
# Expand segments to include the gaps between them
|
130 |
+
if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
|
131 |
# When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
|
132 |
+
merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=config.max_merge_size)
|
133 |
+
elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
|
134 |
# With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
|
135 |
merged = self.expand_gaps(merged, total_duration=max_audio_duration)
|
136 |
else:
|
137 |
+
raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
|
138 |
|
139 |
print("Transcribing non-speech:")
|
140 |
pprint(merged)
|
|
|
193 |
languageCounter[segment_result['language']] += 1
|
194 |
|
195 |
# Update prompt window
|
196 |
+
self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
|
197 |
|
198 |
if detected_language is not None:
|
199 |
result['language'] = detected_language
|
200 |
|
201 |
return result
|
202 |
|
203 |
+
def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
|
204 |
+
if (config.max_prompt_window is not None and config.max_prompt_window > 0):
|
205 |
# Add segments to the current prompt window (unless it is a speech gap)
|
206 |
if not segment_gap:
|
207 |
for segment in adjusted_segments:
|
|
|
213 |
# Time expanded in the segments should be discounted from the prompt window
|
214 |
first_expand_time = prompt_window[0].get('expand_amount', 0)
|
215 |
|
216 |
+
if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
|
217 |
prompt_window.popleft()
|
218 |
else:
|
219 |
break
|
|
|
371 |
return result
|
372 |
|
373 |
class VadSileroTranscription(AbstractTranscription):
|
374 |
+
def __init__(self, sampling_rate: int = 16000):
|
375 |
+
super().__init__(sampling_rate=sampling_rate)
|
376 |
+
|
377 |
+
self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
|
378 |
+
(self.get_speech_timestamps, _, _, _, _) = utils
|
379 |
+
|
380 |
+
|
381 |
+
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
audio_duration = get_audio_duration(audio)
|
383 |
result = []
|
384 |
|
|
|
404 |
|
405 |
# A very simple VAD that just marks every N seconds as speech
|
406 |
class VadPeriodicTranscription(AbstractTranscription):
|
407 |
+
def __init__(self, sampling_rate: int = 16000):
|
408 |
+
super().__init__(sampling_rate=sampling_rate)
|
|
|
409 |
|
410 |
+
def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig):
|
411 |
# Get duration in seconds
|
412 |
audio_duration = get_audio_duration(audio)
|
413 |
result = []
|
|
|
416 |
start_timestamp = 0
|
417 |
|
418 |
while (start_timestamp < audio_duration):
|
419 |
+
end_timestamp = min(start_timestamp + config.periodic_duration, audio_duration)
|
420 |
segment_duration = end_timestamp - start_timestamp
|
421 |
|
422 |
# Minimum duration is 1 second
|