import torchaudio
from VAD.vad_iterator import VADIterator
from baseHandler import BaseHandler
import numpy as np
import torch
from rich.console import Console

from utils.utils import int2float
from df.enhance import enhance, init_df
import logging

logger = logging.getLogger(__name__)

console = Console()


class VADHandler(BaseHandler):
    """
    Handles voice activity detection. When voice activity is detected, audio is
    accumulated until the end of speech is detected; the whole segment is then
    passed on to the next stage of the pipeline.
    """

    def setup(
        self,
        should_listen,
        thresh=0.3,
        sample_rate=16000,
        min_silence_ms=1000,
        min_speech_ms=500,
        max_speech_ms=float("inf"),
        speech_pad_ms=30,
        audio_enhancement=False,
    ):
        self.should_listen = should_listen
        self.sample_rate = sample_rate
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.max_speech_ms = max_speech_ms
        # Silero VAD model from torch hub; VADIterator wraps it for streaming
        # start/end-of-speech detection on successive chunks.
        self.model, _ = torch.hub.load("snakers4/silero-vad", "silero_vad")
        self.iterator = VADIterator(
            self.model,
            threshold=thresh,
            sampling_rate=sample_rate,
            min_silence_duration_ms=min_silence_ms,
            speech_pad_ms=speech_pad_ms,
        )
        self.audio_enhancement = audio_enhancement
        if audio_enhancement:
            # DeepFilterNet model and state, used to denoise captured segments.
            self.enhanced_model, self.df_state, _ = init_df()

    def process(self, audio_chunk):
        audio_int16 = np.frombuffer(audio_chunk, dtype=np.int16)
        audio_float32 = int2float(audio_int16)
        vad_output = self.iterator(torch.from_numpy(audio_float32))
        if vad_output is not None and len(vad_output) != 0:
            console.print("VAD: end of speech detected")
            logger.debug("VAD: end of speech detected")
            array = torch.cat(vad_output).cpu().numpy()
            duration_ms = len(array) / self.sample_rate * 1000
            if duration_ms < self.min_speech_ms or duration_ms > self.max_speech_ms:
                console.print(
                    f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
                )
                logger.debug(
                    f"audio input of duration: {len(array) / self.sample_rate}s, skipping"
                )
            else:
                self.should_listen.clear()
                logger.debug("Stop listening")
                if self.audio_enhancement:
                    if self.sample_rate != self.df_state.sr():
                        # DeepFilterNet runs at its own sample rate, so resample the
                        # segment up for enhancement and back down afterwards.
                        audio_float32 = torchaudio.functional.resample(
                            torch.from_numpy(array),
                            orig_freq=self.sample_rate,
                            new_freq=self.df_state.sr(),
                        )
                        enhanced = enhance(
                            self.enhanced_model,
                            self.df_state,
                            audio_float32.unsqueeze(0),
                        )
                        enhanced = torchaudio.functional.resample(
                            enhanced,
                            orig_freq=self.df_state.sr(),
                            new_freq=self.sample_rate,
                        )
                    else:
                        # Enhance the full accumulated segment (enhance expects a
                        # (channels, samples) tensor), not just the last raw chunk.
                        enhanced = enhance(
                            self.enhanced_model,
                            self.df_state,
                            torch.from_numpy(array).unsqueeze(0),
                        )
                    array = enhanced.numpy().squeeze()
                yield array

    @property
    def min_time_to_debug(self):
        return 0.00001
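
# Minimal usage sketch (illustrative only). How VADHandler is wired into the
# pipeline is up to BaseHandler and the surrounding app; the constructor
# arguments and the `mic_chunks()` source below are assumptions, not this
# module's API:
#
#     from threading import Event
#
#     should_listen = Event()
#     vad = VADHandler(stop_event, queue_in, queue_out)  # hypothetical BaseHandler args
#     vad.setup(should_listen, thresh=0.3, sample_rate=16000)
#     for chunk in mic_chunks():            # hypothetical stream of raw int16 bytes
#         for speech in vad.process(chunk):
#             ...                           # `speech`: one full segment (float32 numpy array)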