import io
import wave
from typing import Any, Generator, List, Optional, Tuple

import numpy as np
import requests
import webrtcvad
from openai import OpenAI
from transformers import pipeline

from utils.errors import APIError, AudioConversionError

SAMPLE_RATE: int = 48000  # Hz
FRAME_DURATION: int = 30  # ms per VAD frame


def detect_voice(audio: np.ndarray, sample_rate: int = SAMPLE_RATE, frame_duration: int = FRAME_DURATION) -> bool:
    """
    Detect voice activity in the given audio data.

    Args:
        audio (np.ndarray): Audio data as a numpy array of 16-bit PCM samples.
        sample_rate (int): Sample rate of the audio. Defaults to SAMPLE_RATE.
        frame_duration (int): Duration of each frame in milliseconds. Defaults to FRAME_DURATION.

    Returns:
        bool: True if voice activity is detected, False otherwise.
    """
    vad = webrtcvad.Vad(3)  # Aggressiveness mode: 3 (most aggressive)
    audio_bytes = audio.tobytes()

    num_samples_per_frame = int(sample_rate * frame_duration / 1000)
    frame_size = num_samples_per_frame * 2  # 2 bytes per 16-bit sample
    frames = [audio_bytes[i : i + frame_size] for i in range(0, len(audio_bytes), frame_size)]

    count_speech = 0
    for frame in frames:
        if len(frame) < frame_size:
            continue
        if vad.is_speech(frame, sample_rate):
            count_speech += 1
            if count_speech > 6:
                return True
    return False
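
# Illustrative usage (not part of the module's public API): at 48 kHz with
# 30 ms frames, each VAD frame is 48000 * 30 / 1000 = 1440 samples (2880 bytes
# of 16-bit PCM), and just over 200 ms of voiced frames is enough for
# detect_voice to return True.
#
#     silence = np.zeros(SAMPLE_RATE, dtype=np.int16)  # one second of silence
#     detect_voice(silence)                            # -> False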


class STTManager:
    """Manages speech-to-text operations."""

    def __init__(self, config: Any):
        """
        Initialize the STTManager.

        Args:
            config (Any): Configuration object containing STT settings.
        """
        self.config = config
        self.SAMPLE_RATE: int = SAMPLE_RATE
        self.CHUNK_LENGTH: int = 5
        self.STEP_LENGTH: int = 3
        self.MAX_RELIABILITY_CUTOFF: int = self.CHUNK_LENGTH - 1

        # Load the local pipeline (if any) before the self-test below;
        # otherwise test_stt() would fail for HF_LOCAL because self.pipe
        # would not exist yet.
        if config.stt.type == "HF_LOCAL":
            self.pipe = pipeline("automatic-speech-recognition", model=config.stt.name)

        self.status: bool = self.test_stt()
        self.streaming: bool = self.status

    def numpy_audio_to_bytes(self, audio_data: np.ndarray) -> bytes:
        """
        Convert numpy array audio data to WAV-encoded bytes.

        Args:
            audio_data (np.ndarray): Audio data as a numpy array of 16-bit PCM samples.

        Returns:
            bytes: Audio data as mono 16-bit WAV bytes.

        Raises:
            AudioConversionError: If there's an error during conversion.
        """
        buffer = io.BytesIO()
        try:
            with wave.open(buffer, "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)
                wf.setframerate(self.SAMPLE_RATE)
                wf.writeframes(audio_data.tobytes())
        except Exception as e:
            raise AudioConversionError(f"Error converting numpy array to audio bytes: {e}")
        return buffer.getvalue()
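
    # Illustrative round trip (assumes the caller already holds int16 mono
    # audio at SAMPLE_RATE; `stt_manager` is a constructed STTManager):
    #
    #     pcm = np.zeros(SAMPLE_RATE, dtype=np.int16)
    #     wav_bytes = stt_manager.numpy_audio_to_bytes(pcm)
    #     with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
    #         assert wf.getnchannels() == 1 and wf.getframerate() == SAMPLE_RATE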

    def process_audio_chunk(self, audio: Tuple[int, np.ndarray], audio_buffer: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Process an audio chunk and update the audio buffer.

        Args:
            audio (Tuple[int, np.ndarray]): Audio chunk as a (sample_rate, samples) tuple.
            audio_buffer (np.ndarray): Existing audio buffer.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The updated audio buffer and the audio that is
                ready for transcription (empty while speech is still being accumulated).
        """
        has_voice = detect_voice(audio[1])
        ended = len(audio[1]) % 24000 != 0

        if has_voice:
            audio_buffer = np.concatenate((audio_buffer, audio[1]))

        is_short = len(audio_buffer) / self.SAMPLE_RATE < 1.0

        if is_short or (has_voice and not ended):
            return audio_buffer, np.array([], dtype=np.int16)

        return np.array([], dtype=np.int16), audio_buffer
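
    # Sketch of the intended streaming loop (illustrative; assumes an audio
    # source that yields (sample_rate, np.ndarray) chunks, e.g. a Gradio
    # streaming input):
    #
    #     buffer, text = np.array([], dtype=np.int16), ""
    #     for chunk in audio_stream:
    #         buffer, ready = stt_manager.process_audio_chunk(chunk, buffer)
    #         if ready.size:
    #             text = stt_manager.transcribe_audio(ready, text)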

    def transcribe_audio(self, audio: np.ndarray, text: str = "") -> str:
        """
        Transcribe audio data and append to existing text.

        Args:
            audio (np.ndarray): Audio data to transcribe.
            text (str): Existing text to append to. Defaults to empty string.

        Returns:
            str: Transcribed text appended to existing text.
        """
        if len(audio) < 500:
            return text
        transcript = self.transcribe_numpy_array(audio, context=text)
        return f"{text} {transcript}".strip()

    def transcribe_and_add_to_chat(self, audio: np.ndarray, chat: List[List[Optional[str]]]) -> List[List[Optional[str]]]:
        """
        Transcribe audio and add the result to the chat history.

        Args:
            audio (np.ndarray): Audio data to transcribe.
            chat (List[List[Optional[str]]]): Existing chat history.

        Returns:
            List[List[Optional[str]]]: Updated chat history with transcribed text.
        """
        text = self.transcribe_audio(audio)
        return self.add_to_chat(text, chat)

    def add_to_chat(self, text: str, chat: List[List[Optional[str]]]) -> List[List[Optional[str]]]:
        """
        Add text to the chat history.

        Args:
            text (str): Text to add to chat.
            chat (List[List[Optional[str]]]): Existing chat history.

        Returns:
            List[List[Optional[str]]]: Updated chat history.
        """
        if not text:
            return chat
        if not chat or chat[-1][0] is None:
            chat.append(["", None])
        chat[-1][0] = text
        return chat
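
    # Behaviour sketch (illustrative): the chat history is a list of
    # [user_message, assistant_message] pairs, and repeated calls overwrite
    # the last user turn, which suits incremental streaming transcripts.
    #
    #     chat = stt_manager.add_to_chat("Tell me", [])          # [["Tell me", None]]
    #     chat = stt_manager.add_to_chat("Tell me a joke", chat)
    #     # -> [["Tell me a joke", None]]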

    def transcribe_numpy_array(self, audio: np.ndarray, context: Optional[str] = None) -> str:
        """
        Transcribe audio data using the configured STT service.

        Args:
            audio (np.ndarray): Audio data as a numpy array.
            context (Optional[str]): Optional context for transcription.

        Returns:
            str: Transcribed text.

        Raises:
            APIError: If there's an unexpected error during transcription.
        """
        transcription_methods = {
            "OPENAI_API": self._transcribe_openai,
            "HF_API": self._transcribe_hf_api,
            "HF_LOCAL": self._transcribe_hf_local,
        }
        try:
            transcribe_method = transcription_methods.get(self.config.stt.type)
            if transcribe_method:
                return transcribe_method(audio, context)
            raise APIError(f"Unsupported STT type: {self.config.stt.type}")
        except APIError:
            raise
        except Exception as e:
            raise APIError(f"STT Error: Unexpected error: {e}")

    def _transcribe_openai(self, audio: np.ndarray, context: Optional[str]) -> str:
        """
        Transcribe audio using the OpenAI API.

        Args:
            audio (np.ndarray): Audio data as a numpy array.
            context (Optional[str]): Optional context for transcription.

        Returns:
            str: Transcribed text.
        """
        audio_bytes = self.numpy_audio_to_bytes(audio)
        data = ("temp.wav", audio_bytes, "audio/wav")
        client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
        return client.audio.transcriptions.create(model=self.config.stt.name, file=data, response_format="text", prompt=context)

    def _transcribe_hf_api(self, audio: np.ndarray, _context: Optional[str]) -> str:
        """
        Transcribe audio using the Hugging Face API.

        Args:
            audio (np.ndarray): Audio data as a numpy array.
            _context (Optional[str]): Unused context parameter.

        Returns:
            str: Transcribed text.

        Raises:
            APIError: If there's an error in the API response.
        """
        audio_bytes = self.numpy_audio_to_bytes(audio)
        headers = {"Authorization": f"Bearer {self.config.stt.key}"}
        response = requests.post(self.config.stt.url, headers=headers, data=audio_bytes)
        if response.status_code != 200:
            error_details = response.json().get("error", "No error message provided")
            raise APIError("STT Error: HF API error", status_code=response.status_code, details=error_details)
        transcription = response.json().get("text")
        if transcription is None:
            raise APIError("STT Error: No transcription returned by HF API")
        return transcription

    def _transcribe_hf_local(self, audio: np.ndarray, _context: Optional[str]) -> str:
        """
        Transcribe audio using a local Hugging Face model.

        Args:
            audio (np.ndarray): Audio data as a numpy array.
            _context (Optional[str]): Unused context parameter.

        Returns:
            str: Transcribed text.
        """
        # The pipeline expects float32 samples in [-1.0, 1.0], so int16 input is scaled down.
        result = self.pipe({"sampling_rate": self.SAMPLE_RATE, "raw": audio.astype(np.float32) / 32768.0})
        return result["text"]

    def test_stt(self) -> bool:
        """
        Test the STT functionality.

        Returns:
            bool: True if the test is successful, False otherwise.
        """
        try:
            self.transcribe_audio(np.zeros(10000, dtype=np.int16))
            return True
        except Exception:
            return False
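
    # Constructing the manager (illustrative; the exact config layout is an
    # assumption, but it must expose stt.type, stt.name, stt.url and stt.key
    # as used above):
    #
    #     from types import SimpleNamespace
    #     config = SimpleNamespace(stt=SimpleNamespace(
    #         type="OPENAI_API", name="whisper-1",
    #         url="https://api.openai.com/v1", key="sk-..."))
    #     stt_manager = STTManager(config)
    #     print(stt_manager.status, stt_manager.streaming)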


class TTSManager:
    """Manages text-to-speech operations."""

    def __init__(self, config: Any):
        """
        Initialize the TTSManager.

        Args:
            config (Any): Configuration object containing TTS settings.
        """
        self.config = config
        self.SAMPLE_RATE: int = SAMPLE_RATE
        self.status: bool = self.test_tts(stream=False)
        self.streaming: bool = self.test_tts(stream=True) if self.status else False

    def test_tts(self, stream: bool) -> bool:
        """
        Test the TTS functionality.

        Args:
            stream (bool): Whether to test streaming TTS.

        Returns:
            bool: True if the test is successful, False otherwise.
        """
        try:
            list(self.read_text("Handshake", stream=stream))
            return True
        except Exception:
            return False

    def read_text(self, text: str, stream: Optional[bool] = None) -> Generator[bytes, None, None]:
        """
        Convert text to speech using the configured TTS service.

        Args:
            text (str): Text to convert to speech.
            stream (Optional[bool]): Whether to stream the audio. Defaults to self.streaming if not provided.

        Yields:
            bytes: Audio data in bytes.

        Raises:
            APIError: If there's an unexpected error during text-to-speech conversion.
        """
        if not text:
            yield b""
            return

        stream = self.streaming if stream is None else stream
        headers = {"Authorization": f"Bearer {self.config.tts.key}"}
        data = {"model": self.config.tts.name, "input": text, "voice": "alloy", "response_format": "opus"}

        try:
            yield from self._read_text_stream(headers, data) if stream else self._read_text_non_stream(headers, data)
        except APIError:
            raise
        except Exception as e:
            raise APIError(f"TTS Error: Unexpected error: {e}")

    def _read_text_non_stream(self, headers: dict, data: dict) -> Generator[bytes, None, None]:
        """
        Handle non-streaming TTS requests.

        Args:
            headers (dict): Request headers.
            data (dict): Request data.

        Yields:
            bytes: Audio data in bytes.

        Raises:
            APIError: If there's an error in the API response.
        """
        if self.config.tts.type == "OPENAI_API":
            url = f"{self.config.tts.url}/audio/speech"
        elif self.config.tts.type == "HF_API":
            url = self.config.tts.url
            data = {"inputs": data["input"]}
        else:
            raise APIError(f"TTS Error: Unsupported TTS type: {self.config.tts.type}")

        response = requests.post(url, headers=headers, json=data)
        if response.status_code != 200:
            error_details = response.json().get("error", "No error message provided")
            raise APIError(f"TTS Error: {self.config.tts.type} error", status_code=response.status_code, details=error_details)
        yield response.content

    def _read_text_stream(self, headers: dict, data: dict) -> Generator[bytes, None, None]:
        """
        Handle streaming TTS requests.

        Args:
            headers (dict): Request headers.
            data (dict): Request data.

        Yields:
            bytes: Audio data in bytes.

        Raises:
            APIError: If there's an error in the API response or if streaming is not supported.
        """
        if self.config.tts.type != "OPENAI_API":
            raise APIError("TTS Error: Streaming not supported for this TTS type")

        url = f"{self.config.tts.url}/audio/speech"
        with requests.post(url, headers=headers, json=data, stream=True) as response:
            if response.status_code != 200:
                error_details = response.json().get("error", "No error message provided")
                raise APIError("TTS Error: OPENAI API error", status_code=response.status_code, details=error_details)
            yield from response.iter_content(chunk_size=1024)

    def read_last_message(self, chat_history: List[List[Optional[str]]]) -> Generator[bytes, None, None]:
        """
        Read the last message in the chat history.

        Args:
            chat_history (List[List[Optional[str]]]): Chat history.

        Yields:
            bytes: Audio data for the last message.
        """
        if chat_history and chat_history[-1][1]:
            yield from self.read_text(chat_history[-1][1])
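

# Minimal manual smoke test (illustrative sketch, not part of the module: the
# config import path and attribute layout below are assumptions, and real API
# credentials would be required):
#
#     from utils.config import Config  # hypothetical location of the config object
#     config = Config()
#     stt = STTManager(config)
#     tts = TTSManager(config)
#     print(f"STT ready: {stt.status} (streaming: {stt.streaming})")
#     print(f"TTS ready: {tts.status} (streaming: {tts.streaming})")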