from time import perf_counter
from copy import copy
import logging

import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
)
from rich.console import Console

from baseHandler import BaseHandler

logger = logging.getLogger(__name__)

console = Console()

SUPPORTED_LANGUAGES = [
    "en",
    "fr",
    "es",
    "zh",
    "ja",
    "ko",
]


class WhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.
    """

    def setup(
        self,
        model_name="distil-whisper/distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        language=None,
        gen_kwargs={},
    ):
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.compile_mode = compile_mode
        self.gen_kwargs = gen_kwargs
        if language == "auto":
            language = None
        self.last_language = language
        if self.last_language is not None:
            self.gen_kwargs["language"] = self.last_language

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
        ).to(device)

        # compile
        if self.compile_mode:
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )
        self.warmup()

    def prepare_model_inputs(self, spoken_prompt):
        input_features = self.processor(
            spoken_prompt, sampling_rate=16000, return_tensors="pt"
        ).input_features
        input_features = input_features.to(self.device, dtype=self.torch_dtype)

        return input_features

    def warmup(self):
        logger.info(f"Warming up {self.__class__.__name__}")

        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
        n_steps = 1 if self.compile_mode == "default" else 2
        # 3000 mel frames correspond to Whisper's fixed 30 s input window
        dummy_input = torch.randn(
            (1, self.model.config.num_mel_bins, 3000),
            dtype=self.torch_dtype,
            device=self.device,
        )

        if self.compile_mode not in (None, "default"):
            # generating more tokens than previously will trigger CUDA graphs capture
            # one should warmup with a number of generated tokens above max tokens targeted for subsequent generation
            # hence, having min_new_tokens < max_new_tokens in the future doesn't make sense
            warmup_gen_kwargs = {
                **self.gen_kwargs,
                # Yes, assign max_new_tokens to min_new_tokens so the warmup generation
                # reaches the longest sequence length used later on
                "min_new_tokens": self.gen_kwargs["max_new_tokens"],
                "max_new_tokens": self.gen_kwargs["max_new_tokens"],
            }
        else:
            warmup_gen_kwargs = self.gen_kwargs

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()

            logger.info(
                f"{self.__class__.__name__}: warmed up! "
                f"time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, spoken_prompt):
        logger.debug("inferring whisper...")
        console.print("inferring whisper...")

        # mark the start of the pipeline for latency measurement
        global pipeline_start
        pipeline_start = perf_counter()

        input_features = self.prepare_model_inputs(spoken_prompt)
        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # remove "<|" and "|>"

        if language_code not in SUPPORTED_LANGUAGES:  # reprocess with the last language
            logger.warning("Whisper detected unsupported language: %s", language_code)
            console.print("Whisper detected unsupported language:", language_code)
            gen_kwargs = copy(self.gen_kwargs)
            gen_kwargs["language"] = self.last_language
            language_code = self.last_language
            pred_ids = self.model.generate(input_features, **gen_kwargs)
        else:
            self.last_language = language_code

        pred_text = self.processor.batch_decode(
            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
        )[0]
        # re-read the language token from the final prediction
        language_code = self.processor.tokenizer.decode(pred_ids[0, 1])[2:-2]  # remove "<|" and "|>"

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")
        console.print(f"Language Code Whisper: {language_code}")
        logger.debug(f"Language Code Whisper: {language_code}")

        yield (pred_text, language_code)
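

if __name__ == "__main__":
    # --- Minimal usage sketch (illustrative only, not part of the original handler) ---
    # Shows how the handler might be exercised outside the full pipeline: we bypass
    # BaseHandler.__init__ (whose queue/event wiring belongs to the pipeline), call
    # setup() directly, and feed a 16 kHz waveform to process(). The one-second random
    # waveform below is an assumption for demonstration; with noise as input the
    # resulting transcription is meaningless.
    import numpy as np

    handler = WhisperSTTHandler.__new__(WhisperSTTHandler)  # skip pipeline-specific __init__
    handler.setup(
        model_name="distil-whisper/distil-large-v3",
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype="float16" if torch.cuda.is_available() else "float32",
        gen_kwargs={"max_new_tokens": 128},
    )
    dummy_audio = np.random.randn(16000).astype(np.float32)  # 1 s of audio at 16 kHz
    for text, language in handler.process(dummy_audio):
        console.print(f"transcription: {text!r} (language: {language})")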