import logging
from copy import copy
from time import perf_counter

import torch
from rich.console import Console
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
)

from baseHandler import BaseHandler

logger = logging.getLogger(__name__)
console = Console()

SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "es",
    "zh",
    "ja",
    "ko",
]
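# Languages detected outside this list make process() fall back to the last
# successfully detected language instead of trusting the detection.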


class WhisperSTTHandler(BaseHandler):
    """
    Handles speech-to-text generation using a Whisper model.
    """

    def setup(
        self,
        model_name="distil-whisper/distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        language=None,
        gen_kwargs=None,
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.compile_mode = compile_mode
        # Avoid a shared mutable default argument: this dict is mutated below
        # when a language is pinned.
        self.gen_kwargs = gen_kwargs if gen_kwargs is not None else {}
        if language == "auto":
            language = None  # let Whisper detect the language per utterance
        self.last_language = language
        if self.last_language is not None:
            self.gen_kwargs["language"] = self.last_language

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
        ).to(device)

        if self.compile_mode:
            # A static KV cache is required to compile the forward pass with
            # fullgraph=True.
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )
        self.warmup()

    def prepare_model_inputs(self, spoken_prompt):
        input_features = self.processor(
            spoken_prompt, sampling_rate=16000, return_tensors="pt"
        ).input_features
        input_features = input_features.to(self.device, dtype=self.torch_dtype)

        return input_features

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # One step suffices for the default compile mode; eager mode and
        # CUDA-graph compile modes (e.g. "reduce-overhead") get a second step.
        n_steps = 1 if self.compile_mode == "default" else 2
        # 3000 mel frames corresponds to Whisper's fixed 30 s input window.
        dummy_input = torch.randn(
            (1, self.model.config.num_mel_bins, 3000),
            dtype=self.torch_dtype,
            device=self.device,
        )
        if self.compile_mode not in (None, "default"):
            # Generating fewer tokens during warmup than during later inference
            # would retrigger CUDA graph capture, so pin min_new_tokens to
            # max_new_tokens for the warmup runs.
            warmup_gen_kwargs = {
                **self.gen_kwargs,
                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
            }
        else:
            warmup_gen_kwargs = self.gen_kwargs

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()
            # The CUDA events only exist on GPU, so report timing there only.
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, spoken_prompt):
        logger.debug("inferring whisper...")
        console.print("inferring whisper...")

        # Marks the start of the pipeline timing; read elsewhere to measure
        # end-to-end latency.
        global pipeline_start
        pipeline_start = perf_counter()

        input_features = self.prepare_model_inputs(spoken_prompt)
        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
        # The token at position 1 is the language token, e.g. "<|en|>";
        # slicing [2:-2] strips the "<|" and "|>" markers.
        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]

        if language_code not in SUPPORTED_LANGUAGES:
            logger.warning(f"Whisper detected unsupported language: {language_code}")
            console.print(f"Whisper detected unsupported language: {language_code}")
            # Fall back to the last known language and transcribe again.
            gen_kwargs = copy(self.gen_kwargs)
            gen_kwargs["language"] = self.last_language
            language_code = self.last_language
            pred_ids = self.model.generate(input_features, **gen_kwargs)
        else:
            self.last_language = language_code

        pred_text = self.processor.batch_decode(
            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
        )[0]
        # Re-read the language token from the final predictions, which may
        # have been regenerated above.
        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")
        console.print(f"Language Code Whisper: {language_code}")
        logger.debug(f"Language Code Whisper: {language_code}")

        yield (pred_text, language_code)
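

# ---------------------------------------------------------------------------
# Minimal standalone sketch of the inference path above, kept for reference.
# It mirrors process(): featurize 16 kHz audio, generate, then read the
# language token at position 1. The zero-filled "audio" is a stand-in for a
# real recording, and running this downloads the checkpoint; it is an
# illustrative sketch, not part of the handler pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    model_name = "distil-whisper/distil-large-v3"
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name)

    audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz
    features = processor(
        audio, sampling_rate=16000, return_tensors="pt"
    ).input_features
    pred_ids = model.generate(features, max_new_tokens=32)

    text = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
    language = processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # e.g. "en"
    print(f"text={text!r} language={language!r}")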