Adding JSON initial prompt
Browse filesBy selecting "json_prompt_mode", you can
customize the prompt to each segment.
For instance:
[
{"segment_index": 0, "prompt": "Hello, how are you?"},
{"segment_index": 1, "prompt": "I'm doing well, how are you?"},
{"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
]
- app.py +16 -4
- cli.py +2 -2
- src/config.py +5 -0
- src/prompts/abstractPromptStrategy.py +73 -0
- src/prompts/jsonPromptStrategy.py +48 -0
- src/prompts/prependPromptStrategy.py +31 -0
- src/whisper/abstractWhisperContainer.py +9 -24
- src/whisper/fasterWhisperContainer.py +14 -12
- src/whisper/whisperContainer.py +18 -13
app.py
CHANGED
@@ -13,12 +13,14 @@ import numpy as np
|
|
13 |
|
14 |
import torch
|
15 |
|
16 |
-
from src.config import ApplicationConfig, VadInitialPromptMode
|
17 |
from src.hooks.progressListener import ProgressListener
|
18 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
19 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
20 |
from src.languages import get_language_names
|
21 |
from src.modelCache import ModelCache
|
|
|
|
|
22 |
from src.source import get_audio_source_collection
|
23 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
24 |
|
@@ -271,8 +273,18 @@ class WhisperTranscriber:
|
|
271 |
if ('task' in decodeOptions):
|
272 |
task = decodeOptions.pop('task')
|
273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
# Callable for processing an audio file
|
275 |
-
whisperCallable = model.create_callback(language, task,
|
276 |
|
277 |
# The results
|
278 |
if (vadOptions.vad == 'silero-vad'):
|
@@ -519,7 +531,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
519 |
*common_vad_inputs(),
|
520 |
gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
|
521 |
gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
|
522 |
-
gr.Dropdown(choices=
|
523 |
|
524 |
*common_word_timestamps_inputs(),
|
525 |
gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
|
@@ -580,7 +592,7 @@ if __name__ == '__main__':
|
|
580 |
help="The default model name.") # medium
|
581 |
parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
|
582 |
help="The default VAD.") # silero-vad
|
583 |
-
parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=
|
584 |
help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
|
585 |
parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
|
586 |
help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
|
|
|
13 |
|
14 |
import torch
|
15 |
|
16 |
+
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
17 |
from src.hooks.progressListener import ProgressListener
|
18 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
19 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
20 |
from src.languages import get_language_names
|
21 |
from src.modelCache import ModelCache
|
22 |
+
from src.prompts.jsonPromptStrategy import JsonPromptStrategy
|
23 |
+
from src.prompts.prependPromptStrategy import PrependPromptStrategy
|
24 |
from src.source import get_audio_source_collection
|
25 |
from src.vadParallel import ParallelContext, ParallelTranscription
|
26 |
|
|
|
273 |
if ('task' in decodeOptions):
|
274 |
task = decodeOptions.pop('task')
|
275 |
|
276 |
+
if (vadOptions.vadInitialPromptMode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS or
|
277 |
+
vadOptions.vadInitialPromptMode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
|
278 |
+
# Prepend initial prompt
|
279 |
+
prompt_strategy = PrependPromptStrategy(initial_prompt, vadOptions.vadInitialPromptMode)
|
280 |
+
elif (vadOptions.vadInitialPromptMode == VadInitialPromptMode.JSON_PROMPT_MODE):
|
281 |
+
# Use a JSON format to specify the prompt for each segment
|
282 |
+
prompt_strategy = JsonPromptStrategy(initial_prompt)
|
283 |
+
else:
|
284 |
+
raise ValueError("Invalid vadInitialPromptMode: " + vadOptions.vadInitialPromptMode)
|
285 |
+
|
286 |
# Callable for processing an audio file
|
287 |
+
whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)
|
288 |
|
289 |
# The results
|
290 |
if (vadOptions.vad == 'silero-vad'):
|
|
|
531 |
*common_vad_inputs(),
|
532 |
gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
|
533 |
gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
|
534 |
+
gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode"),
|
535 |
|
536 |
*common_word_timestamps_inputs(),
|
537 |
gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
|
|
|
592 |
help="The default model name.") # medium
|
593 |
parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
|
594 |
help="The default VAD.") # silero-vad
|
595 |
+
parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
|
596 |
help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
|
597 |
parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
|
598 |
help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
|
cli.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
|
|
7 |
|
8 |
import torch
|
9 |
from app import VadOptions, WhisperTranscriber
|
10 |
-
from src.config import ApplicationConfig, VadInitialPromptMode
|
11 |
from src.download import download_url
|
12 |
from src.languages import get_language_names
|
13 |
|
@@ -47,7 +47,7 @@ def cli():
|
|
47 |
|
48 |
parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
|
49 |
help="The voice activity detection algorithm to use") # silero-vad
|
50 |
-
parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=
|
51 |
help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
|
52 |
parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
|
53 |
help="The window size (in seconds) to merge voice segments")
|
|
|
7 |
|
8 |
import torch
|
9 |
from app import VadOptions, WhisperTranscriber
|
10 |
+
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
11 |
from src.download import download_url
|
12 |
from src.languages import get_language_names
|
13 |
|
|
|
47 |
|
48 |
parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
|
49 |
help="The voice activity detection algorithm to use") # silero-vad
|
50 |
+
parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
|
51 |
help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
|
52 |
parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
|
53 |
help="The window size (in seconds) to merge voice segments")
|
src/config.py
CHANGED
@@ -24,9 +24,12 @@ class ModelConfig:
|
|
24 |
self.path = path
|
25 |
self.type = type
|
26 |
|
|
|
|
|
27 |
class VadInitialPromptMode(Enum):
|
28 |
PREPEND_ALL_SEGMENTS = 1
|
29 |
PREPREND_FIRST_SEGMENT = 2
|
|
|
30 |
|
31 |
@staticmethod
|
32 |
def from_string(s: str):
|
@@ -36,6 +39,8 @@ class VadInitialPromptMode(Enum):
|
|
36 |
return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
|
37 |
elif normalized == "prepend_first_segment":
|
38 |
return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
|
|
|
|
|
39 |
else:
|
40 |
raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
|
41 |
|
|
|
24 |
self.path = path
|
25 |
self.type = type
|
26 |
|
27 |
+
VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
|
28 |
+
|
29 |
class VadInitialPromptMode(Enum):
|
30 |
PREPEND_ALL_SEGMENTS = 1
|
31 |
PREPREND_FIRST_SEGMENT = 2
|
32 |
+
JSON_PROMPT_MODE = 3
|
33 |
|
34 |
@staticmethod
|
35 |
def from_string(s: str):
|
|
|
39 |
return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
|
40 |
elif normalized == "prepend_first_segment":
|
41 |
return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
|
42 |
+
elif normalized == "json_prompt_mode":
|
43 |
+
return VadInitialPromptMode.JSON_PROMPT_MODE
|
44 |
else:
|
45 |
raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
|
46 |
|
src/prompts/abstractPromptStrategy.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
|
3 |
+
|
4 |
+
class AbstractPromptStrategy:
|
5 |
+
"""
|
6 |
+
Represents a strategy for generating prompts for a given audio segment.
|
7 |
+
|
8 |
+
Note that the strategy must be picklable, as it will be serialized and sent to the workers.
|
9 |
+
"""
|
10 |
+
|
11 |
+
@abc.abstractmethod
|
12 |
+
def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
|
13 |
+
"""
|
14 |
+
Retrieves the prompt for a given segment.
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
segment_index: int
|
19 |
+
The index of the segment.
|
20 |
+
whisper_prompt: str
|
21 |
+
The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
|
22 |
+
detected_language: str
|
23 |
+
The language detected for the segment.
|
24 |
+
"""
|
25 |
+
pass
|
26 |
+
|
27 |
+
@abc.abstractmethod
|
28 |
+
def on_segment_finished(self, segment_index: int, whisper_prompt: str, detected_language: str, result: dict):
|
29 |
+
"""
|
30 |
+
Called when a segment has finished processing.
|
31 |
+
|
32 |
+
Parameters
|
33 |
+
----------
|
34 |
+
segment_index: int
|
35 |
+
The index of the segment.
|
36 |
+
whisper_prompt: str
|
37 |
+
The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
|
38 |
+
detected_language: str
|
39 |
+
The language detected for the segment.
|
40 |
+
result: dict
|
41 |
+
The result of the segment. It has the following format:
|
42 |
+
{
|
43 |
+
"text": str,
|
44 |
+
"segments": [
|
45 |
+
{
|
46 |
+
"text": str,
|
47 |
+
"start": float,
|
48 |
+
"end": float,
|
49 |
+
"words": [words],
|
50 |
+
}
|
51 |
+
],
|
52 |
+
"language": str,
|
53 |
+
}
|
54 |
+
"""
|
55 |
+
pass
|
56 |
+
|
57 |
+
def _concat_prompt(self, prompt1, prompt2):
|
58 |
+
"""
|
59 |
+
Concatenates two prompts.
|
60 |
+
|
61 |
+
Parameters
|
62 |
+
----------
|
63 |
+
prompt1: str
|
64 |
+
The first prompt.
|
65 |
+
prompt2: str
|
66 |
+
The second prompt.
|
67 |
+
"""
|
68 |
+
if (prompt1 is None):
|
69 |
+
return prompt2
|
70 |
+
elif (prompt2 is None):
|
71 |
+
return prompt1
|
72 |
+
else:
|
73 |
+
return prompt1 + " " + prompt2
|
src/prompts/jsonPromptStrategy.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
3 |
+
|
4 |
+
|
5 |
+
class JsonPromptSegment():
|
6 |
+
def __init__(self, segment_index: int, prompt: str, format_prompt: bool = False):
|
7 |
+
self.prompt = prompt
|
8 |
+
self.segment_index = segment_index
|
9 |
+
self.format_prompt = format_prompt
|
10 |
+
|
11 |
+
class JsonPromptStrategy(AbstractPromptStrategy):
|
12 |
+
def __init__(self, initial_json_prompt: str):
|
13 |
+
"""
|
14 |
+
Parameters
|
15 |
+
----------
|
16 |
+
initial_json_prompt: str
|
17 |
+
The initial prompts for each segment in JSON form.
|
18 |
+
|
19 |
+
Format:
|
20 |
+
[
|
21 |
+
{"segment_index": 0, "prompt": "Hello, how are you?"},
|
22 |
+
{"segment_index": 1, "prompt": "I'm doing well, how are you?"},
|
23 |
+
{"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
|
24 |
+
]
|
25 |
+
|
26 |
+
"""
|
27 |
+
parsed_json = json.loads(initial_json_prompt)
|
28 |
+
self.segment_lookup = dict[str, JsonPromptSegment]()
|
29 |
+
|
30 |
+
for prompt_entry in parsed_json:
|
31 |
+
segment_index = prompt_entry["segment_index"]
|
32 |
+
prompt = prompt_entry["prompt"]
|
33 |
+
format_prompt = prompt_entry.get("format_prompt", False)
|
34 |
+
self.segment_lookup[str(segment_index)] = JsonPromptSegment(segment_index, prompt, format_prompt)
|
35 |
+
|
36 |
+
def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
|
37 |
+
# Lookup prompt
|
38 |
+
prompt = self.segment_lookup.get(str(segment_index), None)
|
39 |
+
|
40 |
+
if (prompt is None):
|
41 |
+
# No prompt found, return whisper prompt
|
42 |
+
print(f"Could not find prompt for segment {segment_index}, returning whisper prompt")
|
43 |
+
return whisper_prompt
|
44 |
+
|
45 |
+
if (prompt.format_prompt):
|
46 |
+
return prompt.prompt.format(whisper_prompt)
|
47 |
+
else:
|
48 |
+
return self._concat_prompt(prompt.prompt, whisper_prompt)
|
src/prompts/prependPromptStrategy.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.config import VadInitialPromptMode
|
2 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
3 |
+
|
4 |
+
class PrependPromptStrategy(AbstractPromptStrategy):
|
5 |
+
"""
|
6 |
+
A simple prompt strategy that prepends a single prompt to all segments of audio, or prepends the prompt to the first segment of audio.
|
7 |
+
"""
|
8 |
+
def __init__(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode):
|
9 |
+
"""
|
10 |
+
Parameters
|
11 |
+
----------
|
12 |
+
initial_prompt: str
|
13 |
+
The initial prompt to use for the transcription.
|
14 |
+
initial_prompt_mode: VadInitialPromptMode
|
15 |
+
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
|
16 |
+
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
|
17 |
+
"""
|
18 |
+
self.initial_prompt = initial_prompt
|
19 |
+
self.initial_prompt_mode = initial_prompt_mode
|
20 |
+
|
21 |
+
# This is a simple prompt strategy, so we only support these two modes
|
22 |
+
if initial_prompt_mode not in [VadInitialPromptMode.PREPEND_ALL_SEGMENTS, VadInitialPromptMode.PREPREND_FIRST_SEGMENT]:
|
23 |
+
raise ValueError(f"Unsupported initial prompt mode {initial_prompt_mode}")
|
24 |
+
|
25 |
+
def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
|
26 |
+
if (self.initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
|
27 |
+
return self._concat_prompt(self.initial_prompt, whisper_prompt)
|
28 |
+
elif (self.initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
|
29 |
+
return self._concat_prompt(self.initial_prompt, whisper_prompt) if segment_index == 0 else whisper_prompt
|
30 |
+
else:
|
31 |
+
raise ValueError(f"Unknown initial prompt mode {self.initial_prompt_mode}")
|
src/whisper/abstractWhisperContainer.py
CHANGED
@@ -1,11 +1,16 @@
|
|
1 |
import abc
|
2 |
from typing import List
|
|
|
3 |
from src.config import ModelConfig, VadInitialPromptMode
|
4 |
|
5 |
from src.hooks.progressListener import ProgressListener
|
6 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
|
|
7 |
|
8 |
class AbstractWhisperCallback:
|
|
|
|
|
|
|
9 |
@abc.abstractmethod
|
10 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
11 |
"""
|
@@ -24,23 +29,6 @@ class AbstractWhisperCallback:
|
|
24 |
"""
|
25 |
raise NotImplementedError()
|
26 |
|
27 |
-
def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode,
|
28 |
-
prompt: str, segment_index: int):
|
29 |
-
if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
|
30 |
-
return self._concat_prompt(initial_prompt, prompt)
|
31 |
-
elif (initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
|
32 |
-
return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
|
33 |
-
else:
|
34 |
-
raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")
|
35 |
-
|
36 |
-
def _concat_prompt(self, prompt1, prompt2):
|
37 |
-
if (prompt1 is None):
|
38 |
-
return prompt2
|
39 |
-
elif (prompt2 is None):
|
40 |
-
return prompt1
|
41 |
-
else:
|
42 |
-
return prompt1 + " " + prompt2
|
43 |
-
|
44 |
class AbstractWhisperContainer:
|
45 |
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
46 |
download_root: str = None,
|
@@ -75,8 +63,8 @@ class AbstractWhisperContainer:
|
|
75 |
pass
|
76 |
|
77 |
@abc.abstractmethod
|
78 |
-
def create_callback(self, language: str = None, task: str = None,
|
79 |
-
|
80 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
81 |
"""
|
82 |
Create a WhisperCallback object that can be used to transcript audio files.
|
@@ -87,11 +75,8 @@ class AbstractWhisperContainer:
|
|
87 |
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
88 |
task: str
|
89 |
The task - either translate or transcribe.
|
90 |
-
|
91 |
-
The
|
92 |
-
initial_prompt_mode: VadInitialPromptMode
|
93 |
-
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
|
94 |
-
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
|
95 |
decodeOptions: dict
|
96 |
Additional options to pass to the decoder. Must be pickleable.
|
97 |
|
|
|
1 |
import abc
|
2 |
from typing import List
|
3 |
+
|
4 |
from src.config import ModelConfig, VadInitialPromptMode
|
5 |
|
6 |
from src.hooks.progressListener import ProgressListener
|
7 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
8 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
9 |
|
10 |
class AbstractWhisperCallback:
|
11 |
+
def __init__(self):
|
12 |
+
self.__prompt_mode_gpt = None
|
13 |
+
|
14 |
@abc.abstractmethod
|
15 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
16 |
"""
|
|
|
29 |
"""
|
30 |
raise NotImplementedError()
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
class AbstractWhisperContainer:
|
33 |
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
34 |
download_root: str = None,
|
|
|
63 |
pass
|
64 |
|
65 |
@abc.abstractmethod
|
66 |
+
def create_callback(self, language: str = None, task: str = None,
|
67 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
68 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
69 |
"""
|
70 |
Create a WhisperCallback object that can be used to transcript audio files.
|
|
|
75 |
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
76 |
task: str
|
77 |
The task - either translate or transcribe.
|
78 |
+
prompt_strategy: AbstractPromptStrategy
|
79 |
+
The prompt strategy to use for the transcription.
|
|
|
|
|
|
|
80 |
decodeOptions: dict
|
81 |
Additional options to pass to the decoder. Must be pickleable.
|
82 |
|
src/whisper/fasterWhisperContainer.py
CHANGED
@@ -6,6 +6,7 @@ from src.config import ModelConfig, VadInitialPromptMode
|
|
6 |
from src.hooks.progressListener import ProgressListener
|
7 |
from src.languages import get_language_from_name
|
8 |
from src.modelCache import ModelCache
|
|
|
9 |
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
|
10 |
from src.utils import format_timestamp
|
11 |
|
@@ -56,8 +57,8 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
56 |
model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
|
57 |
return model
|
58 |
|
59 |
-
def create_callback(self, language: str = None, task: str = None,
|
60 |
-
|
61 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
62 |
"""
|
63 |
Create a WhisperCallback object that can be used to transcript audio files.
|
@@ -68,11 +69,8 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
68 |
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
69 |
task: str
|
70 |
The task - either translate or transcribe.
|
71 |
-
|
72 |
-
The
|
73 |
-
initial_prompt_mode: VadInitialPromptMode
|
74 |
-
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
|
75 |
-
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
|
76 |
decodeOptions: dict
|
77 |
Additional options to pass to the decoder. Must be pickleable.
|
78 |
|
@@ -80,17 +78,16 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
80 |
-------
|
81 |
A WhisperCallback object.
|
82 |
"""
|
83 |
-
return FasterWhisperCallback(self, language=language, task=task,
|
84 |
|
85 |
class FasterWhisperCallback(AbstractWhisperCallback):
|
86 |
def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
|
87 |
-
|
88 |
**decodeOptions: dict):
|
89 |
self.model_container = model_container
|
90 |
self.language = language
|
91 |
self.task = task
|
92 |
-
self.
|
93 |
-
self.initial_prompt_mode = initial_prompt_mode
|
94 |
self.decodeOptions = decodeOptions
|
95 |
|
96 |
self._printed_warning = False
|
@@ -138,7 +135,8 @@ class FasterWhisperCallback(AbstractWhisperCallback):
|
|
138 |
# See if supress_tokens is a string - if so, convert it to a list of ints
|
139 |
decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
|
140 |
|
141 |
-
initial_prompt = self.
|
|
|
142 |
|
143 |
segments_generator, info = model.transcribe(audio, \
|
144 |
language=language_code if language_code else detected_language, task=self.task, \
|
@@ -184,6 +182,10 @@ class FasterWhisperCallback(AbstractWhisperCallback):
|
|
184 |
"duration": info.duration if info else None
|
185 |
}
|
186 |
|
|
|
|
|
|
|
|
|
187 |
if progress_listener is not None:
|
188 |
progress_listener.on_finished()
|
189 |
return result
|
|
|
6 |
from src.hooks.progressListener import ProgressListener
|
7 |
from src.languages import get_language_from_name
|
8 |
from src.modelCache import ModelCache
|
9 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
10 |
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
|
11 |
from src.utils import format_timestamp
|
12 |
|
|
|
57 |
model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
|
58 |
return model
|
59 |
|
60 |
+
def create_callback(self, language: str = None, task: str = None,
|
61 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
62 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
63 |
"""
|
64 |
Create a WhisperCallback object that can be used to transcript audio files.
|
|
|
69 |
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
70 |
task: str
|
71 |
The task - either translate or transcribe.
|
72 |
+
prompt_strategy: AbstractPromptStrategy
|
73 |
+
The prompt strategy to use. If not specified, the prompt from Whisper will be used.
|
|
|
|
|
|
|
74 |
decodeOptions: dict
|
75 |
Additional options to pass to the decoder. Must be pickleable.
|
76 |
|
|
|
78 |
-------
|
79 |
A WhisperCallback object.
|
80 |
"""
|
81 |
+
return FasterWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
|
82 |
|
83 |
class FasterWhisperCallback(AbstractWhisperCallback):
|
84 |
def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
|
85 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
86 |
**decodeOptions: dict):
|
87 |
self.model_container = model_container
|
88 |
self.language = language
|
89 |
self.task = task
|
90 |
+
self.prompt_strategy = prompt_strategy
|
|
|
91 |
self.decodeOptions = decodeOptions
|
92 |
|
93 |
self._printed_warning = False
|
|
|
135 |
# See if supress_tokens is a string - if so, convert it to a list of ints
|
136 |
decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
|
137 |
|
138 |
+
initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
|
139 |
+
if self.prompt_strategy else prompt
|
140 |
|
141 |
segments_generator, info = model.transcribe(audio, \
|
142 |
language=language_code if language_code else detected_language, task=self.task, \
|
|
|
182 |
"duration": info.duration if info else None
|
183 |
}
|
184 |
|
185 |
+
# If we have a prompt strategy, we need to increment the current prompt
|
186 |
+
if self.prompt_strategy:
|
187 |
+
self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
|
188 |
+
|
189 |
if progress_listener is not None:
|
190 |
progress_listener.on_finished()
|
191 |
return result
|
src/whisper/whisperContainer.py
CHANGED
@@ -15,6 +15,7 @@ from src.config import ModelConfig, VadInitialPromptMode
|
|
15 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
16 |
|
17 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
|
|
18 |
from src.utils import download_file
|
19 |
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
|
20 |
|
@@ -69,8 +70,8 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
69 |
|
70 |
return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
|
71 |
|
72 |
-
def create_callback(self, language: str = None, task: str = None,
|
73 |
-
|
74 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
75 |
"""
|
76 |
Create a WhisperCallback object that can be used to transcript audio files.
|
@@ -81,11 +82,8 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
81 |
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
82 |
task: str
|
83 |
The task - either translate or transcribe.
|
84 |
-
|
85 |
-
The
|
86 |
-
initial_prompt_mode: VadInitialPromptMode
|
87 |
-
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
|
88 |
-
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
|
89 |
decodeOptions: dict
|
90 |
Additional options to pass to the decoder. Must be pickleable.
|
91 |
|
@@ -93,7 +91,7 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
93 |
-------
|
94 |
A WhisperCallback object.
|
95 |
"""
|
96 |
-
return WhisperCallback(self, language=language, task=task,
|
97 |
|
98 |
def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
|
99 |
from src.conversion.hf_converter import convert_hf_whisper
|
@@ -162,13 +160,14 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
162 |
return model_config.path
|
163 |
|
164 |
class WhisperCallback(AbstractWhisperCallback):
|
165 |
-
def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None,
|
166 |
-
|
|
|
167 |
self.model_container = model_container
|
168 |
self.language = language
|
169 |
self.task = task
|
170 |
-
self.
|
171 |
-
|
172 |
self.decodeOptions = decodeOptions
|
173 |
|
174 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
@@ -201,11 +200,17 @@ class WhisperCallback(AbstractWhisperCallback):
|
|
201 |
if self.model_container.compute_type in ["fp16", "float16"]:
|
202 |
decodeOptions["fp16"] = True
|
203 |
|
204 |
-
initial_prompt = self.
|
|
|
205 |
|
206 |
result = model.transcribe(audio, \
|
207 |
language=self.language if self.language else detected_language, task=self.task, \
|
208 |
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
)
|
|
|
|
|
|
|
|
|
|
|
211 |
return result
|
|
|
15 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
16 |
|
17 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
18 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
19 |
from src.utils import download_file
|
20 |
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
|
21 |
|
|
|
70 |
|
71 |
return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
|
72 |
|
73 |
+
def create_callback(self, language: str = None, task: str = None,
|
74 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
75 |
**decodeOptions: dict) -> AbstractWhisperCallback:
|
76 |
"""
|
77 |
Create a WhisperCallback object that can be used to transcript audio files.
|
|
|
82 |
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
83 |
task: str
|
84 |
The task - either translate or transcribe.
|
85 |
+
prompt_strategy: AbstractPromptStrategy
|
86 |
+
The prompt strategy to use. If not specified, the prompt from Whisper will be used.
|
|
|
|
|
|
|
87 |
decodeOptions: dict
|
88 |
Additional options to pass to the decoder. Must be pickleable.
|
89 |
|
|
|
91 |
-------
|
92 |
A WhisperCallback object.
|
93 |
"""
|
94 |
+
return WhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
|
95 |
|
96 |
def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
|
97 |
from src.conversion.hf_converter import convert_hf_whisper
|
|
|
160 |
return model_config.path
|
161 |
|
162 |
class WhisperCallback(AbstractWhisperCallback):
|
163 |
+
def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None,
|
164 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
165 |
+
**decodeOptions: dict):
|
166 |
self.model_container = model_container
|
167 |
self.language = language
|
168 |
self.task = task
|
169 |
+
self.prompt_strategy = prompt_strategy
|
170 |
+
|
171 |
self.decodeOptions = decodeOptions
|
172 |
|
173 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
|
|
200 |
if self.model_container.compute_type in ["fp16", "float16"]:
|
201 |
decodeOptions["fp16"] = True
|
202 |
|
203 |
+
initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
|
204 |
+
if self.prompt_strategy else prompt
|
205 |
|
206 |
result = model.transcribe(audio, \
|
207 |
language=self.language if self.language else detected_language, task=self.task, \
|
208 |
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
)
|
211 |
+
|
212 |
+
# If we have a prompt strategy, we need to increment the current prompt
|
213 |
+
if self.prompt_strategy:
|
214 |
+
self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
|
215 |
+
|
216 |
return result
|