Add CLI support for faster-whisper
Changed files:
- app.py +4 -1
- cli.py +6 -2
- config.json5 +3 -1
- src/config.py +2 -3
- src/whisper/abstractWhisperContainer.py +12 -3
- src/whisper/fasterWhisperContainer.py +41 -8
- src/whisper/whisperContainer.py +14 -4
- src/whisper/whisperFactory.py +4 -3
app.py CHANGED
@@ -126,7 +126,8 @@ class WhisperTranscriber:
         selectedModel = modelName if modelName is not None else "base"
 
         model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
-                                         model_name=selectedModel,
+                                         model_name=selectedModel, compute_type=self.app_config.compute_type,
+                                         cache=self.model_cache, models=self.app_config.models)
 
         # Result
         download = []
@@ -518,6 +519,8 @@ if __name__ == '__main__':
                         help="directory to save the outputs")
     parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
                         help="the Whisper implementation to use")
+    parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["int8", "int8_float16", "int16", "float16"], \
+                        help="the compute type to use for inference")
 
     args = parser.parse_args().__dict__
 
cli.py CHANGED
@@ -80,6 +80,8 @@ def cli():
                         help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
     parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
                         help="whether to perform inference in fp16; True by default")
+    parser.add_argument("--compute_type", type=str, default=app_config.compute_type, choices=["int8", "int8_float16", "int16", "float16"], \
+                        help="the compute type to use for inference")
 
     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback, \
                         help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
@@ -119,12 +121,14 @@ def cli():
     vad_cpu_cores = args.pop("vad_cpu_cores")
     auto_parallel = args.pop("auto_parallel")
 
+    compute_type = args.pop("compute_type")
+
     transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
     transcriber.set_auto_parallel(auto_parallel)
 
-    model = create_whisper_container(whisper_implementation=whisper_implementation,
-                                     device=device, download_root=model_dir, models=app_config.models)
+    model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
+                                     device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
 
     if (transcriber._has_parallel_devices()):
         print("Using parallel devices:", transcriber.parallel_device_list)
config.json5 CHANGED
@@ -104,7 +104,7 @@
     // Number of beams in beam search, only applicable when temperature is zero
     "beam_size": 5,
     // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
-    "patience":
+    "patience": 1,
     // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
     "length_penalty": null,
     // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
@@ -115,6 +115,8 @@
     "condition_on_previous_text": true,
     // Whether to perform inference in fp16; True by default
     "fp16": true,
+    // The compute type used by faster-whisper. Can be "int8", "int16" or "float16".
+    "compute_type": "float16",
     // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
     "temperature_increment_on_fallback": 0.2,
     // If the gzip compression ratio is higher than this value, treat the decoding as failed
src/config.py CHANGED
@@ -39,12 +39,10 @@ class ApplicationConfig:
                  patience: float = None, length_penalty: float = None,
                  suppress_tokens: str = "-1", initial_prompt: str = None,
                  condition_on_previous_text: bool = True, fp16: bool = True,
+                 compute_type: str = "float16",
                  temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                  logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
 
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-
         self.models = models
 
         # WebUI settings
@@ -82,6 +80,7 @@ class ApplicationConfig:
         self.initial_prompt = initial_prompt
         self.condition_on_previous_text = condition_on_previous_text
         self.fp16 = fp16
+        self.compute_type = compute_type
         self.temperature_increment_on_fallback = temperature_increment_on_fallback
         self.compression_ratio_threshold = compression_ratio_threshold
         self.logprob_threshold = logprob_threshold
src/whisper/abstractWhisperContainer.py CHANGED
@@ -33,10 +33,12 @@ class AbstractWhisperCallback:
         return prompt1 + " " + prompt2
 
 class AbstractWhisperContainer:
-    def __init__(self, model_name: str, device: str = None,
-                 download_root: str = None, cache: ModelCache = None, models: List[ModelConfig] = []):
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
         self.model_name = model_name
         self.device = device
+        self.compute_type = compute_type
         self.download_root = download_root
         self.cache = cache
 
@@ -87,13 +89,20 @@ class AbstractWhisperContainer:
 
     # This is required for multiprocessing
     def __getstate__(self):
-        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
+        return {
+            "model_name": self.model_name,
+            "device": self.device,
+            "download_root": self.download_root,
+            "models": self.models,
+            "compute_type": self.compute_type
+        }
 
     def __setstate__(self, state):
         self.model_name = state["model_name"]
         self.device = state["device"]
         self.download_root = state["download_root"]
         self.models = state["models"]
+        self.compute_type = state["compute_type"]
         self.model = None
         # Depickled objects must use the global cache
         self.cache = GLOBAL_MODEL_CACHE
src/whisper/fasterWhisperContainer.py CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import List
+from typing import List, Union
 
 from faster_whisper import WhisperModel, download_model
 from src.config import ModelConfig
@@ -8,10 +8,10 @@ from src.modelCache import ModelCache
 from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
 class FasterWhisperContainer(AbstractWhisperContainer):
-    def __init__(self, model_name: str, device: str = None,
-                 download_root: str = None, cache: ModelCache = None,
-                 models: List[ModelConfig] = []):
-        super().__init__(model_name, device, download_root, cache, models)
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
+        super().__init__(model_name, device, compute_type, download_root, cache, models)
 
     def ensure_downloaded(self):
         """
@@ -35,7 +35,7 @@ class FasterWhisperContainer(AbstractWhisperContainer):
         return None
 
     def _create_model(self):
-        print("Loading faster whisper model " + self.model_name)
+        print("Loading faster whisper model " + self.model_name + " for device " + str(self.device))
         model_config = self._get_model_config()
 
         if model_config.type == "whisper" and model_config.url not in ["tiny", "base", "small", "medium", "large", "large-v2"]:
@@ -46,7 +46,7 @@ class FasterWhisperContainer(AbstractWhisperContainer):
         if (device is None):
             device = "auto"
 
-        model = WhisperModel(model_config.url, device=device, compute_type=
+        model = WhisperModel(model_config.url, device=device, compute_type=self.compute_type)
         return model
 
     def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
@@ -96,10 +96,33 @@ class FasterWhisperCallback(AbstractWhisperCallback):
         model: WhisperModel = self.model_container.get_model()
         language_code = self._lookup_language_code(self.language) if self.language else None
 
+        # Copy decode options and remove options that are not supported by faster-whisper
+        decodeOptions = self.decodeOptions.copy()
+        verbose = decodeOptions.pop("verbose", None)
+
+        logprob_threshold = decodeOptions.pop("logprob_threshold", None)
+
+        patience = decodeOptions.pop("patience", None)
+        length_penalty = decodeOptions.pop("length_penalty", None)
+        suppress_tokens = decodeOptions.pop("suppress_tokens", None)
+
+        if (decodeOptions.pop("fp16", None) is not None):
+            print("WARNING: fp16 option is ignored by faster-whisper - use compute_type instead.")
+
+        # Fix up decode options
+        if (logprob_threshold is not None):
+            decodeOptions["log_prob_threshold"] = logprob_threshold
+
+        decodeOptions["patience"] = float(patience) if patience is not None else 1.0
+        decodeOptions["length_penalty"] = float(length_penalty) if length_penalty is not None else 1.0
+
+        # See if suppress_tokens is a string - if so, convert it to a list of ints
+        decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
+
         segments_generator, info = model.transcribe(audio, \
             language=language_code if language_code else detected_language, task=self.task, \
             initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **self.decodeOptions
+            **decodeOptions
         )
 
         segments = []
@@ -109,6 +132,8 @@ class FasterWhisperCallback(AbstractWhisperCallback):
 
             if progress_listener is not None:
                 progress_listener.on_progress(segment.end, info.duration)
+            if verbose:
+                print(segment.text)
 
         text = " ".join([segment.text for segment in segments])
 
@@ -141,6 +166,14 @@ class FasterWhisperCallback(AbstractWhisperCallback):
             progress_listener.on_finished()
         return result
 
+    def _split_suppress_tokens(self, suppress_tokens: Union[str, List[int]]):
+        if (suppress_tokens is None):
+            return None
+        if (isinstance(suppress_tokens, list)):
+            return suppress_tokens
+
+        return [int(token) for token in suppress_tokens.split(",")]
+
     def _lookup_language_code(self, language: str):
         lookup = {
             "english": "en", "chinese": "zh-cn", "german": "de", "spanish": "es", "russian": "ru", "korean": "ko",
src/whisper/whisperContainer.py CHANGED
@@ -4,6 +4,7 @@ import os
 import sys
 from typing import List
 from urllib.parse import urlparse
+import torch
 import urllib3
 from src.hooks.progressListener import ProgressListener
 
@@ -18,9 +19,12 @@ from src.utils import download_file
 from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
 class WhisperContainer(AbstractWhisperContainer):
-    def __init__(self, model_name: str, device: str = None,
-                 download_root: str = None, cache: ModelCache = None, models: List[ModelConfig] = []):
-        super().__init__(model_name, device, download_root, cache, models)
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(model_name, device, compute_type, download_root, cache, models)
 
     def ensure_downloaded(self):
         """
@@ -184,8 +188,14 @@ class WhisperCallback(AbstractWhisperCallback):
         return self._transcribe(model, audio, segment_index, prompt, detected_language)
 
     def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
+        decodeOptions = self.decodeOptions.copy()
+
+        # Add fp16
+        if self.model_container.compute_type in ["fp16", "float16"]:
+            decodeOptions["fp16"] = True
+
         return model.transcribe(audio, \
             language=self.language if self.language else detected_language, task=self.task, \
             initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **self.decodeOptions
+            **decodeOptions
         )
src/whisper/whisperFactory.py CHANGED
@@ -4,15 +4,16 @@ from src.config import ModelConfig
 from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
 
 def create_whisper_container(whisper_implementation: str,
-                             model_name: str, device: str = None,
+                             model_name: str, device: str = None, compute_type: str = "float16",
+                             download_root: str = None,
                              cache: modelCache = None, models: List[ModelConfig] = []) -> AbstractWhisperContainer:
     print("Creating whisper container for " + whisper_implementation)
 
     if (whisper_implementation == "whisper"):
         from src.whisper.whisperContainer import WhisperContainer
-        return WhisperContainer(model_name, device, download_root, cache, models)
+        return WhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
     elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
         from src.whisper.fasterWhisperContainer import FasterWhisperContainer
-        return FasterWhisperContainer(model_name, device, download_root, cache, models)
+        return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
     else:
         raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
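Taken together, a hedged usage sketch of the updated factory signature. The model name and compute type below are example values, faster-whisper is assumed to be installed, and models/download_root are left at their defaults:

from src.whisper.whisperFactory import create_whisper_container

# Illustrative call only - in app.py and cli.py these values come from
# the parsed command line and the loaded ApplicationConfig.
container = create_whisper_container("faster-whisper", model_name="medium",
                                      device="cuda", compute_type="int8_float16")
print(container.compute_type)  # "int8_float16"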