Slower-whisper / src /whisper /fasterWhisperContainer.py
aadnk's picture
Refactor language list
adca588
raw
history blame
7.78 kB
import os
from typing import List, Union
from faster_whisper import WhisperModel, download_model
from src.config import ModelConfig
from src.hooks.progressListener import ProgressListener
from src.languages import get_language_from_name
from src.modelCache import ModelCache
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
class FasterWhisperContainer(AbstractWhisperContainer):
    """
    Whisper container that loads models through faster-whisper.

    The model is selected by matching ``model_name`` against the ``name`` of the
    ModelConfig entries passed in ``models``; the matching entry's ``url`` is either
    a local directory or an identifier that faster-whisper can download.
    """

    # Whisper model names that may be passed to faster-whisper directly; any other
    # "whisper"-type model must first be converted with ct2-transformers-converter.
    _SUPPORTED_WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v2"]

    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
                 download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        """
        Parameters
        ----------
        model_name: str
            Name of the ModelConfig entry to load.
        device: str
            Device to run on; None lets faster-whisper choose ("auto").
        compute_type: str
            faster-whisper compute type (e.g. "float16").
        download_root: str
            Directory used by ensure_downloaded() for downloaded models.
        cache: ModelCache
            Optional model cache, forwarded to the base class.
        models: List[ModelConfig]
            Available model configurations. Defaults to a fresh empty list.
        """
        # A literal [] default would be evaluated once and shared by every instance.
        if models is None:
            models = []
        super().__init__(model_name, device, compute_type, download_root, cache, models)

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
        passing the container to a subprocess.

        Sets ``path`` on the selected ModelConfig: the URL itself when it is an existing
        directory, otherwise the location returned by faster-whisper's download_model
        (downloading into ``download_root``).

        Raises
        ------
        ValueError
            If no ModelConfig matches ``model_name``.
        """
        model_config = self._require_model_config()

        if os.path.isdir(model_config.url):
            model_config.path = model_config.url
        else:
            model_config.path = download_model(model_config.url, output_dir=self.download_root)

    def _get_model_config(self) -> ModelConfig:
        """
        Get the model configuration for the model.

        Returns the first ModelConfig whose name equals ``model_name``, or None if none match.
        """
        for model in self.models:
            if model.name == self.model_name:
                return model
        return None

    def _require_model_config(self) -> ModelConfig:
        """Like _get_model_config(), but raise ValueError instead of returning None."""
        model_config = self._get_model_config()
        if model_config is None:
            raise ValueError("Unknown model: " + str(self.model_name))
        return model_config

    def _create_model(self):
        """
        Instantiate the faster-whisper WhisperModel for the configured model.

        Raises
        ------
        ValueError
            If no ModelConfig matches ``model_name``.
        Exception
            If the model is an unconverted "whisper"-type model faster-whisper cannot load.
        """
        print("Loading faster whisper model " + self.model_name + " for device " + str(self.device))
        model_config = self._require_model_config()

        if model_config.type == "whisper" and model_config.url not in self._SUPPORTED_WHISPER_MODELS:
            raise Exception("FasterWhisperContainer does not yet support Whisper models. Use ct2-transformers-converter to convert the model to a faster-whisper model.")

        # Reuse the location resolved by ensure_downloaded() when available, so a model
        # downloaded into download_root is loaded from there instead of being fetched again.
        model_path = getattr(model_config, "path", None) or model_config.url

        device = self.device if self.device is not None else "auto"

        return WhisperModel(model_path, device=device, compute_type=self.compute_type)

    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
class FasterWhisperCallback(AbstractWhisperCallback):
    """
    Transcription callback that runs the model held by a FasterWhisperContainer.
    """

    def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
        self.initial_prompt = initial_prompt
        # Kept untouched; invoke() works on a translated copy.
        self.decodeOptions = decodeOptions
        # Ensures the fp16 warning is printed at most once per callback.
        self._printed_warning = False

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            Index of the segment being transcribed. For segment 0 the configured initial
            prompt is combined with ``prompt``; otherwise ``prompt`` is used alone.
        prompt: str
            Prompt for this segment, passed to faster-whisper as ``initial_prompt``.
        detected_language: str
            Language to use when no language was configured on this callback.
        progress_listener: ProgressListener
            A callback to receive progress updates.

        Returns
        -------
        dict with keys "segments", "text", "language", "language_probability" and "duration".
        """
        model: WhisperModel = self.model_container.get_model()
        language_code = self._lookup_language_code(self.language) if self.language else None

        decodeOptions, verbose = self._prepare_decode_options()

        segments_generator, info = model.transcribe(
            audio,
            language=language_code if language_code else detected_language,
            task=self.task,
            initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt,
            **decodeOptions
        )

        # transcribe() yields segments lazily; consume them while reporting progress.
        segments = []
        for segment in segments_generator:
            segments.append(segment)
            if progress_listener is not None:
                progress_listener.on_progress(segment.end, info.duration)
            if verbose:
                print(segment.text)

        result = self._to_result(segments, info)

        if progress_listener is not None:
            progress_listener.on_finished()
        return result

    def _prepare_decode_options(self):
        """
        Translate the stored decode options into faster-whisper transcribe() keywords.

        Returns
        -------
        (options, verbose): the translated option dict and the popped "verbose" flag.
        """
        decodeOptions = self.decodeOptions.copy()

        # Remove options that faster-whisper does not accept under these names.
        verbose = decodeOptions.pop("verbose", None)
        logprob_threshold = decodeOptions.pop("logprob_threshold", None)
        patience = decodeOptions.pop("patience", None)
        length_penalty = decodeOptions.pop("length_penalty", None)
        suppress_tokens = decodeOptions.pop("suppress_tokens", None)

        if decodeOptions.pop("fp16", None) is not None:
            if not self._printed_warning:
                print("WARNING: fp16 option is ignored by faster-whisper - use compute_type instead.")
                self._printed_warning = True

        # Rename / coerce to the types faster-whisper expects.
        if logprob_threshold is not None:
            decodeOptions["log_prob_threshold"] = logprob_threshold

        decodeOptions["patience"] = float(patience) if patience is not None else 1.0
        decodeOptions["length_penalty"] = float(length_penalty) if length_penalty is not None else 1.0

        # Only forward suppress_tokens when given; passing None would override
        # faster-whisper's own default for this parameter.
        suppress_tokens = self._split_suppress_tokens(suppress_tokens)
        if suppress_tokens is not None:
            decodeOptions["suppress_tokens"] = suppress_tokens

        return decodeOptions, verbose

    @staticmethod
    def _to_result(segments, info) -> dict:
        """Convert faster-whisper segments and info into a plain, serializable dict."""
        text = " ".join([segment.text for segment in segments])

        whisper_segments = [{
            "text": segment.text,
            "start": segment.start,
            "end": segment.end,

            # Extra fields added by faster-whisper
            "words": [{
                "start": word.start,
                "end": word.end,
                "word": word.word,
                "probability": word.probability
            } for word in (segment.words if segment.words is not None else [])]
        } for segment in segments]

        return {
            "segments": whisper_segments,
            "text": text,
            "language": info.language if info else None,

            # Extra fields added by faster-whisper
            "language_probability": info.language_probability if info else None,
            "duration": info.duration if info else None
        }

    def _split_suppress_tokens(self, suppress_tokens: Union[str, List[int]]):
        """
        Normalize suppress_tokens to a list of ints.

        None stays None; a comma-separated string such as "-1,50" becomes [-1, 50]
        (empty entries are ignored, so "" becomes []); any other sequence is copied to a list.
        """
        if suppress_tokens is None:
            return None
        if isinstance(suppress_tokens, str):
            return [int(token) for token in suppress_tokens.split(",") if token.strip()]
        return list(suppress_tokens)

    def _lookup_language_code(self, language: str):
        """
        Return the language code for ``language`` as resolved by get_language_from_name.

        Raises
        ------
        ValueError
            If the name is not recognized.
        """
        # Use a separate name so the original input is still available for the error message.
        language_info = get_language_from_name(language)

        if language_info is None:
            raise ValueError("Invalid language: " + language)
        return language_info.code