|
from typing import Dict, Any, List, Generator |
|
import torch |
|
import os |
|
import logging |
|
from s2s_pipeline import main, rename_args, parse_arguments, setup_logger, initialize_queues_and_events, build_pipeline |
|
import numpy as np |
|
from queue import Queue |
|
import threading |
|
|
|
class EndpointHandler:
    """Request handler that drives the pipeline assembled by ``build_pipeline``.

    Each request is pushed onto one of the pipeline's input queues. Output
    chunks read from ``send_audio_chunks_queue`` are copied by a background
    thread into ``final_output_queue``, from which ``__call__`` streams them
    back to the caller.
    """

    # Sentinel placed on the output queues to signal that no more chunks follow.
    _END_SENTINEL = b"END"

    def __init__(self, path: str = ""):
        """Parse the configuration, start the pipeline and the output collector.

        Args:
            path: Accepted for interface compatibility; not used by this
                implementation.
        """
        # NOTE(review): parse_arguments() is called without arguments, so the
        # configuration presumably comes from the process command line /
        # defaults rather than from ``path`` -- confirm this is intended.
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = parse_arguments()

        setup_logger(self.module_kwargs.log_level)

        # Each kwargs object is passed to rename_args with its handler prefix;
        # presumably this strips that prefix from the field names -- confirm
        # in s2s_pipeline.rename_args.
        rename_args(self.whisper_stt_handler_kwargs, "stt")
        rename_args(self.paraformer_stt_handler_kwargs, "paraformer_stt")
        rename_args(self.language_model_handler_kwargs, "lm")
        rename_args(self.mlx_language_model_handler_kwargs, "mlx_lm")
        rename_args(self.parler_tts_handler_kwargs, "tts")
        rename_args(self.melo_tts_handler_kwargs, "melo")
        rename_args(self.chat_tts_handler_kwargs, "chat_tts")

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )
        self.pipeline_manager.start()

        # Chunks handed from the collector thread to __call__.
        self.final_output_queue = Queue()

        self.output_collector_thread = threading.Thread(target=self._collect_output)
        self.output_collector_thread.start()

    @staticmethod
    def _is_end_sentinel(item: Any) -> bool:
        """Return True if *item* is the ``b"END"`` sentinel.

        The type check comes first: ``==`` between a multi-element NumPy
        array and bytes compares elementwise on recent NumPy, and an array
        result is ambiguous (``ValueError``) when used in an ``if``.
        """
        return (
            isinstance(item, (bytes, bytearray))
            and item == EndpointHandler._END_SENTINEL
        )

    def _collect_output(self) -> None:
        """Forward chunks from ``send_audio_chunks_queue`` to ``final_output_queue``.

        Runs on the collector thread until the END sentinel is received. The
        sentinel is forwarded as well, so the consumer in ``__call__`` stops.
        """
        # NOTE(review): after END has been forwarded this thread exits, so a
        # later request would wait forever for output. Also, build_pipeline
        # receives socket-sender kwargs and may itself read this queue --
        # confirm there is only one consumer.
        source = self.queues_and_events["send_audio_chunks_queue"]
        while True:
            output = source.get()
            if self._is_end_sentinel(output):
                self.final_output_queue.put(self._END_SENTINEL)
                break
            self.final_output_queue.put(output)

    def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
        """Submit one request to the pipeline and stream its output.

        The input is validated and enqueued as soon as the handler is called.
        The returned generator then yields output chunks.

        Args:
            data: Request payload. ``"input_type"`` selects the input queue
                (``"speech"`` or ``"text"``, default ``"text"``) and
                ``"input"`` holds the payload (default ``""``). Speech input
                must be a bytes-like buffer of int16 samples.

        Returns:
            A generator yielding ``{"output": chunk}`` dicts until the END
            sentinel is read from ``final_output_queue``.

        Raises:
            ValueError: If ``input_type`` is unsupported, or if speech input is
                not a whole number of int16 samples.
        """
        input_type = data.get("input_type", "text")
        input_data = data.get("input", "")

        if input_type == "speech":
            # Decoding as int16 validates the buffer length before queueing;
            # the pipeline receives the same raw bytes.
            audio_array = np.frombuffer(input_data, dtype=np.int16)
            self.queues_and_events["recv_audio_chunks_queue"].put(audio_array.tobytes())
        elif input_type == "text":
            self.queues_and_events["text_prompt_queue"].put(input_data)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        # Enqueue eagerly (above), stream lazily (below). If this method were
        # itself a generator, validation and enqueueing would wait until the
        # caller's first iteration.
        return self._stream_output()

    def _stream_output(self) -> Generator[Dict[str, Any], None, None]:
        """Yield ``{"output": chunk}`` for each collected chunk until END."""
        while True:
            chunk = self.final_output_queue.get()
            if self._is_end_sentinel(chunk):
                return
            yield {"output": chunk}

    def cleanup(self) -> None:
        """Stop the pipeline and shut down the output collector thread.

        An END sentinel is pushed onto ``send_audio_chunks_queue`` so the
        collector's blocking ``get()`` returns, then the thread is joined.
        """
        self.pipeline_manager.stop()
        self.queues_and_events["send_audio_chunks_queue"].put(self._END_SENTINEL)
        self.output_collector_thread.join()