# s2s / handler.py
# Author: andito (HF staff)
# Commit 967aebb (verified): "Upload folder using huggingface_hub"
# Size: 4.04 kB
from typing import Dict, Any, List, Generator
import torch
import os
import logging
from s2s_pipeline import main, rename_args, parse_arguments, setup_logger, initialize_queues_and_events, build_pipeline
import numpy as np
from queue import Queue
import threading
class EndpointHandler:
    """Inference-endpoint adapter around the speech-to-speech (s2s) pipeline.

    On construction the pipeline is configured from ``parse_arguments()``,
    built with ``build_pipeline`` and started. A background thread forwards
    everything the pipeline puts on ``send_audio_chunks_queue`` into
    ``final_output_queue``, from which ``__call__`` streams results.
    """

    # Bytes sentinel that marks the end of the output stream.
    _END_SENTINEL = b"END"

    def __init__(self, path=""):
        """Configure, build and start the pipeline and the output collector.

        Args:
            path: Accepted for the endpoint-handler interface; not used here.
        """
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = parse_arguments()

        setup_logger(self.module_kwargs.log_level)

        # Apply per-handler renaming of the parsed argument groups.
        rename_args(self.whisper_stt_handler_kwargs, "stt")
        rename_args(self.paraformer_stt_handler_kwargs, "paraformer_stt")
        rename_args(self.language_model_handler_kwargs, "lm")
        rename_args(self.mlx_language_model_handler_kwargs, "mlx_lm")
        rename_args(self.parler_tts_handler_kwargs, "tts")
        rename_args(self.melo_tts_handler_kwargs, "melo")
        rename_args(self.chat_tts_handler_kwargs, "chat_tts")

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )
        self.pipeline_manager.start()

        # Queue from which __call__ streams the collected output chunks.
        self.final_output_queue = Queue()

        # Daemon so a handler whose cleanup() is never called does not keep the
        # interpreter alive while this thread blocks on the queue; cleanup()
        # still joins it explicitly.
        # NOTE(review): if build_pipeline also attaches a socket sender to
        # send_audio_chunks_queue, it would compete with this collector for
        # chunks — confirm the pipeline mode used by the endpoint.
        self.output_collector_thread = threading.Thread(
            target=self._collect_output,
            name="s2s-output-collector",
            daemon=True,
        )
        self.output_collector_thread.start()

    @staticmethod
    def _is_end_sentinel(item: Any) -> bool:
        """Return True only for the ``b"END"`` bytes sentinel.

        Chunks are not guaranteed to be ``bytes``; for a numpy array,
        ``item == b"END"`` is an elementwise comparison whose truth value is
        ambiguous (``if`` raises ValueError), so the type is checked first.
        """
        return isinstance(item, bytes) and item == EndpointHandler._END_SENTINEL

    def _collect_output(self):
        """Forward pipeline output to ``final_output_queue`` until the sentinel.

        The sentinel itself is forwarded as well, so consumers see end-of-stream.
        """
        send_queue = self.queues_and_events["send_audio_chunks_queue"]
        while True:
            output = send_queue.get()
            self.final_output_queue.put(output)
            if self._is_end_sentinel(output):
                break

    def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
        """Submit one request to the pipeline and stream its output.

        The input is validated and enqueued immediately when the handler is
        called; only the reading of output is deferred to iteration.

        Args:
            data (Dict[str, Any]): Request payload. ``"input_type"`` is
                ``"speech"`` or ``"text"`` (default ``"text"``); ``"input"`` is
                the payload (default ``""``). Speech input must be a bytes-like
                buffer interpretable as int16 samples (``np.frombuffer``).

        Returns:
            Generator[Dict[str, Any], None, None]: yields ``{"output": chunk}``
            for each collected chunk until the end sentinel is seen.

        Raises:
            ValueError: if ``input_type`` is neither ``"speech"`` nor ``"text"``.
        """
        input_type = data.get("input_type", "text")
        input_data = data.get("input", "")

        if input_type == "speech":
            # Round-trip through int16 so malformed buffers fail here, early.
            audio_array = np.frombuffer(input_data, dtype=np.int16)
            self.queues_and_events["recv_audio_chunks_queue"].put(audio_array.tobytes())
        elif input_type == "text":
            # Text skips the audio front-end and goes straight to the prompt queue.
            self.queues_and_events["text_prompt_queue"].put(input_data)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        return self._stream_output()

    def _stream_output(self) -> Generator[Dict[str, Any], None, None]:
        """Yield ``{"output": chunk}`` from ``final_output_queue`` until the sentinel."""
        while True:
            chunk = self.final_output_queue.get()
            if self._is_end_sentinel(chunk):
                # The collector emits the sentinel only once before exiting;
                # put it back so later calls terminate instead of blocking forever.
                self.final_output_queue.put(chunk)
                return
            yield {"output": chunk}

    def cleanup(self):
        """Stop the pipeline and shut down the output-collector thread."""
        self.pipeline_manager.stop()
        # Unblock the collector in case the pipeline did not emit the sentinel;
        # an extra sentinel after the collector has exited is harmless.
        self.queues_and_events["send_audio_chunks_queue"].put(self._END_SENTINEL)
        self.output_collector_thread.join()