Adding support for faster_whisper
This is a re-implementation of Whisper in CTranslate2 that can be 4x faster and use much less memory than OpenAI's Whisper.
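The implementation is selected with the new --whisper_implementation option added to app.py and cli.py (choices "whisper" and "faster-whisper"), or with the matching whisper_implementation key in config.json5; for example, python cli.py --whisper_implementation faster-whisper ... with the remaining arguments unchanged.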
- app.py +12 -5
- cli.py +9 -4
- config.json5 +3 -0
- requirements-fastWhisper.txt +8 -0
- src/config.py +2 -78
- src/conversion/hf_converter.py +2 -2
- src/hooks/progressListener.py +8 -0
- src/hooks/subTaskProgressListener.py +37 -0
- src/hooks/whisperProgressHook.py +1 -41
- src/utils.py +21 -1
- src/vad.py +5 -3
- src/vadParallel.py +5 -4
- src/whisper/abstractWhisperContainer.py +99 -0
- src/whisper/fasterWhisperContainer.py +165 -0
- src/{whisperContainer.py → whisper/whisperContainer.py} +83 -58
- src/whisper/whisperFactory.py +16 -0
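
Taken together, the change moves the Whisper integration behind an abstract container/callback layer under src/whisper/, and both app.py and cli.py now go through a small factory that picks between the original OpenAI implementation and faster-whisper.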
app.py
CHANGED
@@ -11,8 +11,11 @@ import zipfile
 import numpy as np
 
 import torch
+
 from src.config import ApplicationConfig
-from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
+from src.hooks.progressListener import ProgressListener
+from src.hooks.subTaskProgressListener import SubTaskProgressListener
+from src.hooks.whisperProgressHook import create_progress_listener_handle
 from src.modelCache import ModelCache
 from src.source import get_audio_source_collection
 from src.vadParallel import ParallelContext, ParallelTranscription
@@ -26,7 +29,8 @@ import gradio as gr
 from src.download import ExceededMaximumDuration, download_url
 from src.utils import slugify, write_srt, write_vtt
 from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
-from src.whisperContainer import WhisperContainer
+from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
+from src.whisper.whisperFactory import create_whisper_container
 
 # Configure more application defaults in config.json5
 
@@ -121,7 +125,8 @@ class WhisperTranscriber:
         selectedLanguage = languageName.lower() if len(languageName) > 0 else None
         selectedModel = modelName if modelName is not None else "base"
 
-        model = WhisperContainer(model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
+        model = create_whisper_container(whisper_implementation=app_config.whisper_implementation,
+                                         model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
 
         # Result
         download = []
@@ -223,7 +228,7 @@ class WhisperTranscriber:
         except ExceededMaximumDuration as e:
             return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
 
-    def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
+    def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
                         vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
                         progressListener: ProgressListener = None, **decodeOptions: dict):
 
@@ -507,7 +512,9 @@ if __name__ == '__main__':
     parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
                         help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
     parser.add_argument("--output_dir", "-o", type=str, default=app_config.output_dir, \
-                        help="directory to save the outputs")
+                        help="directory to save the outputs"), \
+    parser.add_argument("--whisper_implementation", type=str, default=app_config.whisper_implementation, choices=["whisper", "faster-whisper"],\
+                        help="the Whisper implementation to use"), \
 
     args = parser.parse_args().__dict__
cli.py
CHANGED
@@ -11,7 +11,7 @@ from src.config import ApplicationConfig
 from src.download import download_url
 
 from src.utils import optional_float, optional_int, str2bool
-from src.whisperContainer import WhisperContainer
+from src.whisper.whisperFactory import create_whisper_container
 
 def cli():
     app_config = ApplicationConfig.create_default()
@@ -32,8 +32,10 @@ def cli():
     parser.add_argument("--output_dir", "-o", type=str, default=output_dir, \
                         help="directory to save the outputs")
     parser.add_argument("--verbose", type=str2bool, default=app_config.verbose, \
-                        help="whether to print out the progress and debug messages")
-
+                        help="whether to print out the progress and debug messages"), \
+    parser.add_argument("--whisper_implementation", type=str, default=app_config.whisper_implementation, choices=["whisper", "faster-whisper"],\
+                        help="the Whisper implementation to use"), \
+
     parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
                         help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
     parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(LANGUAGES), \
@@ -92,6 +94,8 @@ def cli():
     device: str = args.pop("device")
     os.makedirs(output_dir, exist_ok=True)
 
+    whisper_implementation = args.pop("whisper_implementation")
+
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
         warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
         args["language"] = "en"
@@ -115,7 +119,8 @@ def cli():
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
     transcriber.set_auto_parallel(auto_parallel)
 
-    model = WhisperContainer(model_name, device=device, download_root=model_dir, models=app_config.models)
+    model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
+                                     device=device, download_root=model_dir, models=app_config.models)
 
     if (transcriber._has_parallel_devices()):
         print("Using parallel devices:", transcriber.parallel_device_list)
config.json5
CHANGED
@@ -62,6 +62,9 @@
 
     // * General options *
 
+    // The default implementation to use for Whisper. Can be "whisper" or "faster-whisper".
+    "whisper_implementation": "whisper",
+
     // The default model name.
     "default_model_name": "medium",
     // The default VAD.
requirements-fastWhisper.txt
ADDED
@@ -0,0 +1,8 @@
+ctranslate2
+faster-whisper
+ffmpeg-python==0.2.0
+gradio==3.23.0
+yt-dlp
+json5
+torch
+torchaudio
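
This requirements set pulls in ctranslate2 and faster-whisper alongside the app's other dependencies, and notably omits openai-whisper; to run the new backend it would be installed in place of the default requirements, e.g. pip install -r requirements-fastWhisper.txt.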
src/config.py
CHANGED
@@ -8,8 +8,6 @@ import torch
 
 from tqdm import tqdm
 
-from src.conversion.hf_converter import convert_hf_whisper
-
 class ModelConfig:
     def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
         """
@@ -25,86 +23,11 @@ class ModelConfig:
         self.path = path
         self.type = type
 
-    def download_url(self, root_dir: str):
-        import whisper
-
-        # See if path is already set
-        if self.path is not None:
-            return self.path
-
-        if root_dir is None:
-            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
-
-        model_type = self.type.lower() if self.type is not None else "whisper"
-
-        if model_type in ["huggingface", "hf"]:
-            self.path = self.url
-            destination_target = os.path.join(root_dir, self.name + ".pt")
-
-            # Convert from HuggingFace format to Whisper format
-            if os.path.exists(destination_target):
-                print(f"File {destination_target} already exists, skipping conversion")
-            else:
-                print("Saving HuggingFace model in Whisper format to " + destination_target)
-                convert_hf_whisper(self.url, destination_target)
-
-            self.path = destination_target
-
-        elif model_type in ["whisper", "w"]:
-            self.path = self.url
-
-            # See if URL is just a file
-            if self.url in whisper._MODELS:
-                # No need to download anything - Whisper will handle it
-                self.path = self.url
-            elif self.url.startswith("file://"):
-                # Get file path
-                self.path = urlparse(self.url).path
-            # See if it is an URL
-            elif self.url.startswith("http://") or self.url.startswith("https://"):
-                # Extension (or file name)
-                extension = os.path.splitext(self.url)[-1]
-                download_target = os.path.join(root_dir, self.name + extension)
-
-                if os.path.exists(download_target) and not os.path.isfile(download_target):
-                    raise RuntimeError(f"{download_target} exists and is not a regular file")
-
-                if not os.path.isfile(download_target):
-                    self._download_file(self.url, download_target)
-                else:
-                    print(f"File {download_target} already exists, skipping download")
-
-                self.path = download_target
-            # Must be a local file
-            else:
-                self.path = self.url
-
-        else:
-            raise ValueError(f"Unknown model type {model_type}")
-
-        return self.path
-
-    def _download_file(self, url: str, destination: str):
-        with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
-            with tqdm(
-                total=int(source.info().get("Content-Length")),
-                ncols=80,
-                unit="iB",
-                unit_scale=True,
-                unit_divisor=1024,
-            ) as loop:
-                while True:
-                    buffer = source.read(8192)
-                    if not buffer:
-                        break
-
-                    output.write(buffer)
-                    loop.update(len(buffer))
-
 class ApplicationConfig:
     def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
                  share: bool = False, server_name: str = None, server_port: int = 7860,
                  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
+                 whisper_implementation: str = "whisper",
                  default_model_name: str = "medium", default_vad: str = "silero-vad",
                  vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                  auto_parallel: bool = False, output_dir: str = None,
@@ -132,6 +55,7 @@ class ApplicationConfig:
         self.queue_concurrency_count = queue_concurrency_count
         self.delete_uploaded_files = delete_uploaded_files
 
+        self.whisper_implementation = whisper_implementation
         self.default_model_name = default_model_name
         self.default_vad = default_vad
         self.vad_parallel_devices = vad_parallel_devices
src/conversion/hf_converter.py
CHANGED
@@ -2,7 +2,6 @@
 
 from copy import deepcopy
 import torch
-from transformers import WhisperForConditionalGeneration
 
 WHISPER_MAPPING = {
     "layers": "blocks",
@@ -43,7 +42,8 @@ def rename_keys(s_dict):
     return s_dict
 
 
-def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
+def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
+    from transformers import WhisperForConditionalGeneration
     transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
     config = transformer_model.config
 
src/hooks/progressListener.py
ADDED
@@ -0,0 +1,8 @@
+from typing import Union
+
+class ProgressListener:
+    def on_progress(self, current: Union[int, float], total: Union[int, float]):
+        self.total = total
+
+    def on_finished(self):
+        pass
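
ProgressListener now lives in its own module, so code that only needs the listener interface (vad.py, the containers) no longer has to import the tqdm-patching machinery in whisperProgressHook. A concrete listener just overrides the two hooks; a minimal sketch, not part of this commit (class name hypothetical):

from src.hooks.progressListener import ProgressListener

class ConsoleProgressListener(ProgressListener):
    def on_progress(self, current, total):
        # Invoked repeatedly with the amount of work completed so far.
        print(f"progress: {current}/{total}")

    def on_finished(self):
        print("finished")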
src/hooks/subTaskProgressListener.py
ADDED
@@ -0,0 +1,37 @@
+from src.hooks.progressListener import ProgressListener
+
+from typing import Union
+
+class SubTaskProgressListener(ProgressListener):
+    """
+    A sub task listener that reports the progress of a sub task to a base task listener
+    Parameters
+    ----------
+    base_task_listener : ProgressListener
+        The base progress listener to accumulate overall progress in.
+    base_task_total : float
+        The maximum total progress that will be reported to the base progress listener.
+    sub_task_start : float
+        The starting progress of a sub task, in respect to the base progress listener.
+    sub_task_total : float
+        The total amount of progress a sub task will report to the base progress listener.
+    """
+    def __init__(
+        self,
+        base_task_listener: ProgressListener,
+        base_task_total: float,
+        sub_task_start: float,
+        sub_task_total: float,
+    ):
+        self.base_task_listener = base_task_listener
+        self.base_task_total = base_task_total
+        self.sub_task_start = sub_task_start
+        self.sub_task_total = sub_task_total
+
+    def on_progress(self, current: Union[int, float], total: Union[int, float]):
+        sub_task_progress_frac = current / total
+        sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
+        self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
+
+    def on_finished(self):
+        self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
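
on_progress maps the sub task's own [0, total] range linearly onto the slice [sub_task_start, sub_task_start + sub_task_total] of the base task, and on_finished snaps to the end of that slice. A hypothetical sketch splitting a 100-unit base task into two equal halves:

from src.hooks.subTaskProgressListener import SubTaskProgressListener

# base_listener is any ProgressListener, e.g. the console sketch above.
first = SubTaskProgressListener(base_listener, base_task_total=100, sub_task_start=0, sub_task_total=50)
first.on_progress(5, 10)   # base_listener.on_progress(25.0, 100): halfway through the first half
first.on_finished()        # base_listener.on_progress(50, 100)

second = SubTaskProgressListener(base_listener, base_task_total=100, sub_task_start=50, sub_task_total=50)
second.on_finished()       # base_listener.on_progress(100, 100)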
src/hooks/whisperProgressHook.py
CHANGED
@@ -3,12 +3,7 @@ import threading
 from typing import List, Union
 import tqdm
 
-class ProgressListener:
-    def on_progress(self, current: Union[int, float], total: Union[int, float]):
-        self.total = total
-
-    def on_finished(self):
-        pass
+from src.hooks.progressListener import ProgressListener
 
 class ProgressListenerHandle:
     def __init__(self, listener: ProgressListener):
@@ -23,41 +18,6 @@ class ProgressListenerHandle:
         if exc_type is None:
             self.listener.on_finished()
 
-class SubTaskProgressListener(ProgressListener):
-    """
-    A sub task listener that reports the progress of a sub task to a base task listener
-
-    Parameters
-    ----------
-    base_task_listener : ProgressListener
-        The base progress listener to accumulate overall progress in.
-    base_task_total : float
-        The maximum total progress that will be reported to the base progress listener.
-    sub_task_start : float
-        The starting progress of a sub task, in respect to the base progress listener.
-    sub_task_total : float
-        The total amount of progress a sub task will report to the base progress listener.
-    """
-    def __init__(
-        self,
-        base_task_listener: ProgressListener,
-        base_task_total: float,
-        sub_task_start: float,
-        sub_task_total: float,
-    ):
-        self.base_task_listener = base_task_listener
-        self.base_task_total = base_task_total
-        self.sub_task_start = sub_task_start
-        self.sub_task_total = sub_task_total
-
-    def on_progress(self, current: Union[int, float], total: Union[int, float]):
-        sub_task_progress_frac = current / total
-        sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
-        self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
-
-    def on_finished(self):
-        self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
-
 class _CustomProgressBar(tqdm.tqdm):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
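
whisperProgressHook keeps only the tqdm-hooking pieces (ProgressListenerHandle and _CustomProgressBar) and imports ProgressListener back from its new module. Judging from ProgressListenerHandle's __enter__/__exit__, create_progress_listener_handle is meant to be used as a context manager around a transcription call, roughly:

from src.hooks.whisperProgressHook import create_progress_listener_handle

# Sketch only: tqdm updates inside the block are forwarded to the listener,
# and on_finished fires when the block exits without an exception.
with create_progress_listener_handle(listener):
    model.transcribe(audio)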
src/utils.py
CHANGED
@@ -4,6 +4,9 @@ import re
 
 import zlib
 from typing import Iterator, TextIO
+import tqdm
+
+import urllib3
 
 
 def exact_div(x, y):
@@ -112,4 +115,21 @@ def slugify(value, allow_unicode=False):
     else:
         value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
         value = re.sub(r'[^\w\s-]', '', value.lower())
-    return re.sub(r'[-\s]+', '-', value).strip('-_')
+    return re.sub(r'[-\s]+', '-', value).strip('-_')
+
+def download_file(url: str, destination: str):
+    with urllib3.request.urlopen(url) as source, open(destination, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
src/vad.py
CHANGED
@@ -5,11 +5,13 @@ import time
 from typing import Any, Deque, Iterator, List, Dict
 
 from pprint import pprint
-from src.hooks.whisperProgressHook import ProgressListener, SubTaskProgressListener, create_progress_listener_handle
+from src.hooks.progressListener import ProgressListener
+from src.hooks.subTaskProgressListener import SubTaskProgressListener
+from src.hooks.whisperProgressHook import create_progress_listener_handle
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
 from src.segments import merge_timestamps
-from src.whisperContainer import WhisperCallback
+from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
 
 # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
 try:
@@ -136,7 +138,7 @@ class AbstractTranscription(ABC):
             pprint(merged)
         return merged
 
-    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
+    def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
                    progressListener: ProgressListener = None):
        """
        Transcribe the given audo file.
src/vadParallel.py
CHANGED
@@ -2,15 +2,16 @@ import multiprocessing
 from queue import Empty
 import threading
 import time
-from src.hooks.whisperProgressHook import ProgressListener
+from src.hooks.progressListener import ProgressListener
 from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
-from src.whisperContainer import WhisperCallback
 
 from multiprocessing import Pool, Queue
 
 from typing import Any, Dict, List, Union
 import os
 
+from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
+
 class _ProgressListenerToQueue(ProgressListener):
     def __init__(self, progress_queue: Queue):
         self.progress_queue = progress_queue
@@ -104,7 +105,7 @@ class ParallelTranscription(AbstractTranscription):
     def __init__(self, sampling_rate: int = 16000):
         super().__init__(sampling_rate=sampling_rate)
 
-    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
+    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
                             cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
                             progress_listener: ProgressListener = None):
         total_duration = get_audio_duration(audio)
@@ -276,7 +277,7 @@ class ParallelTranscription(AbstractTranscription):
             return config.override_timestamps
         return super().get_merged_timestamps(timestamps, config, total_duration)
 
-    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig,
+    def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: ParallelTranscriptionConfig,
                    progressListener: ProgressListener = None):
         # Override device ID the first time
         if (os.environ.get("INITIALIZED", None) is None):
src/whisper/abstractWhisperContainer.py
ADDED
@@ -0,0 +1,99 @@
+import abc
+from typing import List
+from src.config import ModelConfig
+
+from src.hooks.progressListener import ProgressListener
+from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
+
+class AbstractWhisperCallback:
+    @abc.abstractmethod
+    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
+        """
+        Peform the transcription of the given audio file or data.
+
+        Parameters
+        ----------
+        audio: Union[str, np.ndarray, torch.Tensor]
+            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
+        segment_index: int
+            The target language of the transcription. If not specified, the language will be inferred from the audio content.
+        task: str
+            The task - either translate or transcribe.
+        progress_listener: ProgressListener
+            A callback to receive progress updates.
+        """
+        raise NotImplementedError()
+
+    def _concat_prompt(self, prompt1, prompt2):
+        if (prompt1 is None):
+            return prompt2
+        elif (prompt2 is None):
+            return prompt1
+        else:
+            return prompt1 + " " + prompt2
+
+class AbstractWhisperContainer:
+    def __init__(self, model_name: str, device: str = None, download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
+        self.model_name = model_name
+        self.device = device
+        self.download_root = download_root
+        self.cache = cache
+
+        # Will be created on demand
+        self.model = None
+
+        # List of known models
+        self.models = models
+
+    def get_model(self):
+        if self.model is None:
+
+            if (self.cache is None):
+                self.model = self._create_model()
+            else:
+                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
+                self.model = self.cache.get(model_key, self._create_model)
+        return self.model
+
+    @abc.abstractmethod
+    def _create_model(self):
+        raise NotImplementedError()
+
+    def ensure_downloaded(self):
+        pass
+
+    @abc.abstractmethod
+    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict) -> AbstractWhisperCallback:
+        """
+        Create a WhisperCallback object that can be used to transcript audio files.
+
+        Parameters
+        ----------
+        language: str
+            The target language of the transcription. If not specified, the language will be inferred from the audio content.
+        task: str
+            The task - either translate or transcribe.
+        initial_prompt: str
+            The initial prompt to use for the transcription.
+        decodeOptions: dict
+            Additional options to pass to the decoder. Must be pickleable.
+
+        Returns
+        -------
+        A WhisperCallback object.
+        """
+        raise NotImplementedError()
+
+    # This is required for multiprocessing
+    def __getstate__(self):
+        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
+
+    def __setstate__(self, state):
+        self.model_name = state["model_name"]
+        self.device = state["device"]
+        self.download_root = state["download_root"]
+        self.models = state["models"]
+        self.model = None
+        # Depickled objects must use the global cache
+        self.cache = GLOBAL_MODEL_CACHE
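
The base container creates its model lazily, caches it under a "WhisperContainer.<name>:<device>" key, and deliberately drops self.model and self.cache from the pickled state, so a container shipped to a worker process rebuilds its model through GLOBAL_MODEL_CACHE. A new backend therefore only has to supply _create_model and create_callback; a minimal hypothetical sketch:

from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer

class NullWhisperCallback(AbstractWhisperCallback):
    # Toy backend for illustration: returns an empty transcription result.
    def invoke(self, audio, segment_index, prompt, detected_language, progress_listener=None):
        if progress_listener is not None:
            progress_listener.on_finished()
        return {"segments": [], "text": "", "language": detected_language}

class NullWhisperContainer(AbstractWhisperContainer):
    def _create_model(self):
        return object()  # stand-in for a real model

    def create_callback(self, language=None, task=None, initial_prompt=None, **decodeOptions):
        return NullWhisperCallback()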
src/whisper/fasterWhisperContainer.py
ADDED
@@ -0,0 +1,165 @@
+import os
+from typing import List
+
+from faster_whisper import WhisperModel, download_model
+from src.config import ModelConfig
+from src.hooks.progressListener import ProgressListener
+from src.modelCache import ModelCache
+from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
+
+class FasterWhisperContainer(AbstractWhisperContainer):
+    def __init__(self, model_name: str, device: str = None, download_root: str = None,
+                 cache: ModelCache = None,
+                 models: List[ModelConfig] = []):
+        super().__init__(model_name, device, download_root, cache, models)
+
+    def ensure_downloaded(self):
+        """
+        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
+        passing the container to a subprocess.
+        """
+        model_config = self._get_model_config()
+
+        if os.path.isdir(model_config.url):
+            model_config.path = model_config.url
+        else:
+            model_config.path = download_model(model_config.url, output_dir=self.download_root)
+
+    def _get_model_config(self) -> ModelConfig:
+        """
+        Get the model configuration for the model.
+        """
+        for model in self.models:
+            if model.name == self.model_name:
+                return model
+        return None
+
+    def _create_model(self):
+        print("Loading faster whisper model " + self.model_name)
+        model_config = self._get_model_config()
+
+        if model_config.type == "whisper" and model_config.url not in ["tiny", "base", "small", "medium", "large", "large-v2"]:
+            raise Exception("FasterWhisperContainer does not yet support Whisper models. Use ct2-transformers-converter to convert the model to a faster-whisper model.")
+
+        device = self.device
+
+        if (device is None):
+            device = "auto"
+
+        model = WhisperModel(model_config.url, device=device, compute_type="float16")
+        return model
+
+    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
+        """
+        Create a WhisperCallback object that can be used to transcript audio files.
+
+        Parameters
+        ----------
+        language: str
+            The target language of the transcription. If not specified, the language will be inferred from the audio content.
+        task: str
+            The task - either translate or transcribe.
+        initial_prompt: str
+            The initial prompt to use for the transcription.
+        decodeOptions: dict
+            Additional options to pass to the decoder. Must be pickleable.
+
+        Returns
+        -------
+        A WhisperCallback object.
+        """
+        return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
+
+class FasterWhisperCallback(AbstractWhisperCallback):
+    def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
+        self.model_container = model_container
+        self.language = language
+        self.task = task
+        self.initial_prompt = initial_prompt
+        self.decodeOptions = decodeOptions
+
+    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
+        """
+        Peform the transcription of the given audio file or data.
+
+        Parameters
+        ----------
+        audio: Union[str, np.ndarray, torch.Tensor]
+            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
+        segment_index: int
+            The target language of the transcription. If not specified, the language will be inferred from the audio content.
+        task: str
+            The task - either translate or transcribe.
+        progress_listener: ProgressListener
+            A callback to receive progress updates.
+        """
+        model: WhisperModel = self.model_container.get_model()
+        language_code = self._lookup_language_code(self.language) if self.language else None
+
+        segments_generator, info = model.transcribe(audio, \
+            language=language_code if language_code else detected_language, task=self.task, \
+            initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
+            **self.decodeOptions
+        )
+
+        segments = []
+
+        for segment in segments_generator:
+            segments.append(segment)
+
+            if progress_listener is not None:
+                progress_listener.on_progress(segment.end, info.duration)
+
+        text = " ".join([segment.text for segment in segments])
+
+        # Convert the segments to a format that is easier to serialize
+        whisper_segments = [{
+            "text": segment.text,
+            "start": segment.start,
+            "end": segment.end,
+
+            # Extra fields added by faster-whisper
+            "words": [{
+                "start": word.start,
+                "end": word.end,
+                "word": word.word,
+                "probability": word.probability
+            } for word in (segment.words if segment.words is not None else []) ]
+        } for segment in segments]
+
+        result = {
+            "segments": whisper_segments,
+            "text": text,
+            "language": info.language if info else None,
+
+            # Extra fields added by faster-whisper
+            "language_probability": info.language_probability if info else None,
+            "duration": info.duration if info else None
+        }
+
+        if progress_listener is not None:
+            progress_listener.on_finished()
+        return result
+
+    def _lookup_language_code(self, language: str):
+        lookup = {
+            "english": "en", "chinese": "zh-cn", "german": "de", "spanish": "es", "russian": "ru", "korean": "ko",
+            "french": "fr", "japanese": "ja", "portuguese": "pt", "turkish": "tr", "polish": "pl", "catalan": "ca",
+            "dutch": "nl", "arabic": "ar", "swedish": "sv", "italian": "it", "indonesian": "id", "hindi": "hi",
+            "finnish": "fi", "vietnamese": "vi", "hebrew": "he", "ukrainian": "uk", "greek": "el", "malay": "ms",
+            "czech": "cs", "romanian": "ro", "danish": "da", "hungarian": "hu", "tamil": "ta", "norwegian": "no",
+            "thai": "th", "urdu": "ur", "croatian": "hr", "bulgarian": "bg", "lithuanian": "lt", "latin": "la",
+            "maori": "mi", "malayalam": "ml", "welsh": "cy", "slovak": "sk", "telugu": "te", "persian": "fa",
+            "latvian": "lv", "bengali": "bn", "serbian": "sr", "azerbaijani": "az", "slovenian": "sl",
+            "kannada": "kn", "estonian": "et", "macedonian": "mk", "breton": "br", "basque": "eu", "icelandic": "is",
+            "armenian": "hy", "nepali": "ne", "mongolian": "mn", "bosnian": "bs", "kazakh": "kk", "albanian": "sq",
+            "swahili": "sw", "galician": "gl", "marathi": "mr", "punjabi": "pa", "sinhala": "si", "khmer": "km",
+            "shona": "sn", "yoruba": "yo", "somali": "so", "afrikaans": "af", "occitan": "oc", "georgian": "ka",
+            "belarusian": "be", "tajik": "tg", "sindhi": "sd", "gujarati": "gu", "amharic": "am", "yiddish": "yi",
+            "lao": "lo", "uzbek": "uz", "faroese": "fo", "haitian creole": "ht", "pashto": "ps", "turkmen": "tk",
+            "nynorsk": "nn", "maltese": "mt", "sanskrit": "sa", "luxembourgish": "lb", "myanmar": "my", "tibetan": "bo",
+            "tagalog": "tl", "malagasy": "mg", "assamese": "as", "tatar": "tt", "hawaiian": "haw", "lingala": "ln",
+            "hausa": "ha", "bashkir": "ba", "javanese": "jv", "sundanese": "su"
+        }
+
+        return lookup.get(language.lower() if language is not None else None, language)
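
FasterWhisperCallback drains faster-whisper's segment generator, reporting each segment's end time against info.duration as progress, and then flattens the segments into the dict shape the rest of the app already consumes, with the extra word-level timestamps and language_probability fields. A hedged usage sketch (the audio path and beam_size are illustrative; an official size such as "medium" is required until converted models are supported):

from src.config import ModelConfig
from src.whisper.fasterWhisperContainer import FasterWhisperContainer

models = [ModelConfig(name="medium", url="medium", type="whisper")]
container = FasterWhisperContainer("medium", device="cuda", models=models)

# "Japanese" is resolved to "ja" by _lookup_language_code; beam_size is passed
# through decodeOptions to faster-whisper's transcribe().
callback = container.create_callback(language="Japanese", task="transcribe", beam_size=5)
result = callback.invoke("audio.mp3", segment_index=0, prompt=None, detected_language=None)
print(result["text"], result["language"], result["language_probability"])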
src/{whisperContainer.py → whisper/whisperContainer.py}
RENAMED
@@ -1,40 +1,27 @@
 # External programs
+import abc
 import os
 import sys
 from typing import List
+from urllib.parse import urlparse
+import urllib3
+from src.hooks.progressListener import ProgressListener
 
 import whisper
 from whisper import Whisper
 
 from src.config import ModelConfig
-from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle
+from src.hooks.whisperProgressHook import create_progress_listener_handle
 
 from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
+from src.utils import download_file
+from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
-class WhisperContainer:
-    def __init__(self, model_name: str, device: str = None, download_root: str = None,
-                 cache: ModelCache = None, models: List[ModelConfig] = []):
-        self.model_name = model_name
-        self.device = device
-        self.download_root = download_root
-        self.cache = cache
-
-        # Will be created on demand
-        self.model = None
-
-        # List of known models
-        self.models = models
-
-    def get_model(self):
-        if self.model is None:
-
-            if (self.cache is None):
-                self.model = self._create_model()
-            else:
-                model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
-                self.model = self.cache.get(model_key, self._create_model)
-        return self.model
-
+class WhisperContainer(AbstractWhisperContainer):
+    def __init__(self, model_name: str, device: str = None, download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
+        super().__init__(model_name, device, download_root, cache, models)
 
     def ensure_downloaded(self):
         """
         Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
@@ -43,7 +30,7 @@ class WhisperContainer:
         # Warning: Using private API here
         try:
             root_dir = self.download_root
-            model_config = self.get_model_config()
+            model_config = self._get_model_config()
 
             if root_dir is None:
                 root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
@@ -60,7 +47,7 @@ class WhisperContainer:
             print("Error pre-downloading model: " + str(e))
             return False
 
-    def get_model_config(self) -> ModelConfig:
+    def _get_model_config(self) -> ModelConfig:
        """
        Get the model configuration for the model.
        """
@@ -71,10 +58,10 @@ class WhisperContainer:
 
     def _create_model(self):
         print("Loading whisper model " + self.model_name)
-        model_config = self.get_model_config()
-
+        model_config = self._get_model_config()
+
         # Note that the model will not be downloaded in the case of an official Whisper model
-        model_path = model_config.download_url(self.download_root)
+        model_path = self._get_model_path(model_config, self.download_root)
 
         return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
 
@@ -99,21 +86,73 @@ class WhisperContainer:
         """
         return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
 
-    # This is required for multiprocessing
-    def __getstate__(self):
-        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
+    def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
+        from src.conversion.hf_converter import convert_hf_whisper
+        """
+        Download the model.
+
+        Parameters
+        ----------
+        model_config: ModelConfig
+            The model configuration.
+        """
+        # See if path is already set
+        if model_config.path is not None:
+            return model_config.path
+
+        if root_dir is None:
+            root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+
+        model_type = model_config.type.lower() if model_config.type is not None else "whisper"
+
+        if model_type in ["huggingface", "hf"]:
+            model_config.path = model_config.url
+            destination_target = os.path.join(root_dir, model_config.name + ".pt")
+
+            # Convert from HuggingFace format to Whisper format
+            if os.path.exists(destination_target):
+                print(f"File {destination_target} already exists, skipping conversion")
+            else:
+                print("Saving HuggingFace model in Whisper format to " + destination_target)
+                convert_hf_whisper(model_config.url, destination_target)
+
+            model_config.path = destination_target
+
+        elif model_type in ["whisper", "w"]:
+            model_config.path = model_config.url
+
+            # See if URL is just a file
+            if model_config.url in whisper._MODELS:
+                # No need to download anything - Whisper will handle it
+                model_config.path = model_config.url
+            elif model_config.url.startswith("file://"):
+                # Get file path
+                model_config.path = urlparse(model_config.url).path
+            # See if it is an URL
+            elif model_config.url.startswith("http://") or model_config.url.startswith("https://"):
+                # Extension (or file name)
+                extension = os.path.splitext(model_config.url)[-1]
+                download_target = os.path.join(root_dir, model_config.name + extension)
+
+                if os.path.exists(download_target) and not os.path.isfile(download_target):
+                    raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+                if not os.path.isfile(download_target):
+                    download_file(model_config.url, download_target)
+                else:
+                    print(f"File {download_target} already exists, skipping download")
+
+                model_config.path = download_target
+            # Must be a local file
+            else:
+                model_config.path = model_config.url
 
-    def __setstate__(self, state):
-        self.model_name = state["model_name"]
-        self.device = state["device"]
-        self.download_root = state["download_root"]
-        self.models = state["models"]
-        self.model = None
-        # Depickled objects must use the global cache
-        self.cache = GLOBAL_MODEL_CACHE
+        else:
+            raise ValueError(f"Unknown model type {model_type}")
 
+        return model_config.path
 
-class WhisperCallback:
+class WhisperCallback(AbstractWhisperCallback):
     def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
         self.model_container = model_container
         self.language = language
@@ -133,14 +172,8 @@ class WhisperCallback:
             The target language of the transcription. If not specified, the language will be inferred from the audio content.
         task: str
             The task - either translate or transcribe.
-
-
-        detected_language: str
-            The detected language of the audio file.
-
-        Returns
-        -------
-        The result of the Whisper call.
+        progress_listener: ProgressListener
+            A callback to receive progress updates.
         """
         model = self.model_container.get_model()
 
@@ -155,12 +188,4 @@ class WhisperCallback:
                             language=self.language if self.language else detected_language, task=self.task, \
                             initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
                             **self.decodeOptions
-        )
-
-    def _concat_prompt(self, prompt1, prompt2):
-        if (prompt1 is None):
-            return prompt2
-        elif (prompt2 is None):
-            return prompt1
-        else:
-            return prompt1 + " " + prompt2
+        )
src/whisper/whisperFactory.py
ADDED
@@ -0,0 +1,16 @@
+from typing import List
+from src import modelCache
+from src.config import ModelConfig
+from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
+
+def create_whisper_container(whisper_implementation: str,
+                             model_name: str, device: str = None, download_root: str = None,
+                             cache: modelCache = None, models: List[ModelConfig] = []) -> AbstractWhisperContainer:
+    if (whisper_implementation == "whisper"):
+        from src.whisper.whisperContainer import WhisperContainer
+        return WhisperContainer(model_name, device, download_root, cache, models)
+    elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
+        from src.whisper.fasterWhisperContainer import FasterWhisperContainer
+        return FasterWhisperContainer(model_name, device, download_root, cache, models)
+    else:
+        raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
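
The factory is the single switch between backends, and it imports each container lazily so the unused implementation's dependencies are never loaded. A sketch mirroring how app.py and cli.py call it:

from src.config import ApplicationConfig
from src.whisper.whisperFactory import create_whisper_container

app_config = ApplicationConfig.create_default()
model = create_whisper_container(whisper_implementation=app_config.whisper_implementation,
                                 model_name=app_config.default_model_name,
                                 models=app_config.models)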