import base64
import logging
import os
import threading
import uuid
from queue import Queue, Empty
from typing import Dict, Any, List, Generator

import numpy as np
import torch

from s2s_pipeline import main, prepare_all_args, get_default_arguments, setup_logger, initialize_queues_and_events, build_pipeline

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Poll-based request handler wrapped around the speech-to-speech pipeline.

    A client first sends a ``"start"`` request carrying text or base64-encoded
    audio and receives a ``session_id``.  It then sends ``"continue"`` requests
    with that id to fetch whatever output has been produced since the previous
    poll and, optionally, to push more audio.

    NOTE(review): every session's collector thread reads from the same
    ``send_audio_chunks_queue``, so output cannot be attributed to a specific
    session when several sessions are active at once.  Confirm that callers
    only run one session at a time.
    """

    # Seconds a collector thread blocks on the output queue before it
    # re-checks whether shutdown has been requested.
    _COLLECTOR_POLL_SECONDS = 2

    def __init__(self, path=""):
        """Build and start the pipeline with the default arguments.

        Args:
            path: Accepted for compatibility with the endpoint-handler calling
                convention; not used by this implementation.
        """
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = get_default_arguments(mode='none', log_level='DEBUG', lm_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct')

        setup_logger(self.module_kwargs.log_level)

        prepare_all_args(
            self.module_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        )

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )

        # Queue for collecting the final output
        self.final_output_queue = Queue()
        self.sessions = {}  # Store session information
        self.vad_chunk_size = 512  # Chunk size (in samples) required by the VAD model
        self.sample_rate = 16000  # Expected sample rate

        # Bookkeeping for the per-session collector threads so that cleanup()
        # can actually stop and join them.
        self._sessions_lock = threading.Lock()
        self._stop_event = threading.Event()
        self._collector_threads = []

        self.pipeline_manager.start()

    def _process_audio_chunk(self, audio_data: bytes, session_id: str):
        """Split raw int16 audio into VAD-sized chunks and enqueue them.

        The payload is interpreted as 16-bit samples.  Every chunk placed on
        ``recv_audio_chunks_queue`` holds exactly ``self.vad_chunk_size``
        samples; a shorter final chunk is zero-padded.

        A payload with an odd number of bytes cannot be decoded as int16.  The
        dangling byte is kept in the session's ``'buffer'`` and prepended to
        the next payload of the same session.  If ``session_id`` is unknown,
        the dangling byte is dropped.

        Args:
            audio_data: Raw audio bytes.
            session_id: Session whose carry-over buffer should be used.
        """
        session = self.sessions.get(session_id)
        if session is not None and session['buffer']:
            audio_data = session['buffer'] + audio_data
            session['buffer'] = b''

        dangling = len(audio_data) % np.dtype(np.int16).itemsize
        if dangling:
            if session is not None:
                session['buffer'] = audio_data[-dangling:]
            audio_data = audio_data[:-dangling]

        samples = np.frombuffer(audio_data, dtype=np.int16)

        # Pad once up to a whole number of chunks instead of special-casing
        # the last chunk inside the loop.
        missing = -len(samples) % self.vad_chunk_size
        if missing:
            samples = np.pad(samples, (0, missing), 'constant')

        recv_queue = self.queues_and_events['recv_audio_chunks_queue']
        for start in range(0, len(samples), self.vad_chunk_size):
            recv_queue.put(samples[start:start + self.vad_chunk_size].tobytes())

    def _collect_output(self, session_id):
        """Move pipeline output into the session until an END marker arrives.

        Runs in a background thread.  Items are read from
        ``send_audio_chunks_queue``; numpy arrays are converted to bytes and
        everything else is stored as-is.  A ``b"END"`` / ``"END"`` item marks
        the session ``'completed'`` and ends the loop.  The loop also ends
        when ``cleanup()`` requests shutdown.

        Args:
            session_id: Key of the session in ``self.sessions`` to fill.
        """
        session = self.sessions[session_id]
        send_queue = self.queues_and_events['send_audio_chunks_queue']

        while not self._stop_event.is_set():
            try:
                output = send_queue.get(timeout=self._COLLECTOR_POLL_SECONDS)
            except Empty:
                # Nothing yet: loop around so the stop flag is re-checked.
                continue

            if isinstance(output, (str, bytes)) and output in (b"END", "END"):
                with self._sessions_lock:
                    session['status'] = 'completed'
                break

            if isinstance(output, np.ndarray):
                output = output.tobytes()
            session['chunks'].append(output)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch a request according to its ``"request_type"``.

        Args:
            data: Request payload.  ``"request_type"`` is ``"start"`` (the
                default) or ``"continue"``.

        Returns:
            The response dictionary produced by the matching handler.

        Raises:
            ValueError: If the request type is not supported.
        """
        request_type = data.get("request_type", "start")
        if request_type == "start":
            return self._handle_start_request(data)
        if request_type == "continue":
            return self._handle_continue_request(data)
        raise ValueError(f"Unsupported request type: {request_type}")

    def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Open a new session, feed its first input and start collecting output.

        Args:
            data: Request payload.  ``"input_type"`` is ``"text"`` (the
                default) or ``"speech"``; ``"inputs"`` holds the text prompt
                or the base64-encoded audio.

        Returns:
            ``{"session_id": <new id>, "status": "new"}``.

        Raises:
            ValueError: If the input type is not supported.  A malformed
                base64 payload also surfaces as an exception from
                ``base64.b64decode``.
        """
        input_type = data.get("input_type", "text")
        input_data = data.get("inputs", "")

        # Validate and decode before registering anything, so that a bad
        # request does not leave an orphan session behind.
        if input_type == "speech":
            audio_bytes = base64.b64decode(input_data)
        elif input_type != "text":
            raise ValueError(f"Unsupported input type: {input_type}")

        session_id = str(uuid.uuid4())
        self.sessions[session_id] = {
            'status': 'new',
            'chunks': [],
            'last_sent_index': 0,
            'buffer': b'',  # Dangling audio byte carried over to the next payload
        }

        if input_type == "speech":
            self._process_audio_chunk(audio_bytes, session_id)
        else:
            self.queues_and_events['text_prompt_queue'].put(input_data)

        # Start output collection in a separate thread.  It is a daemon so
        # that a collector still waiting for output cannot keep the process
        # alive, and it is tracked so that cleanup() can join it.
        collector = threading.Thread(
            target=self._collect_output,
            args=(session_id,),
            daemon=True,
        )
        self._collector_threads = [t for t in self._collector_threads if t.is_alive()]
        self._collector_threads.append(collector)
        collector.start()

        return {"session_id": session_id, "status": "new"}

    def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return output produced since the last poll; optionally push more audio.

        While the pipeline's ``should_listen`` event is clear, the session is
        reported as ``'processing'`` and any ``"inputs"`` in the request are
        ignored.  A session that already reached ``'completed'`` keeps that
        status.

        Args:
            data: Request payload with ``"session_id"`` and, optionally,
                ``"inputs"`` holding base64-encoded audio.

        Returns:
            A dictionary with ``"session_id"``, ``"status"`` and ``"output"``,
            where ``"output"`` is the base64-encoded concatenation of the new
            chunks, or ``None`` when there is nothing new.

        Raises:
            ValueError: If ``session_id`` is missing or unknown.
        """
        session_id = data.get("session_id")
        if not session_id or session_id not in self.sessions:
            raise ValueError("Invalid or missing session_id")

        session = self.sessions[session_id]

        if not self.queues_and_events['should_listen'].is_set():
            with self._sessions_lock:
                # 'completed' is terminal: the collector thread has exited,
                # so nothing would ever move the status forward again.
                if session['status'] != 'completed':
                    session['status'] = 'processing'
        elif "inputs" in data:
            # Handle additional input if provided
            audio_bytes = base64.b64decode(data["inputs"])
            self._process_audio_chunk(audio_bytes, session_id)

        # Take one consistent snapshot.  The collector appends chunks before
        # it marks the session completed, so reading the status together with
        # the end index guarantees that a 'completed' response carries every
        # chunk, and that no chunk appended mid-call is skipped.
        with self._sessions_lock:
            status = session['status']
            first_unsent = session['last_sent_index']
            end = len(session['chunks'])
            session['last_sent_index'] = end

        chunks_to_send = session['chunks'][first_unsent:end]
        if chunks_to_send:
            combined_audio = b''.join(chunks_to_send)
            output = base64.b64encode(combined_audio).decode('utf-8')
        else:
            output = None

        return {
            "session_id": session_id,
            "status": status,
            "output": output,
        }

    def cleanup(self):
        """Stop the pipeline and shut down all output collector threads."""
        # Stop the pipeline
        self.pipeline_manager.stop()

        # Ask the collectors to exit, and wake any that are blocked on the
        # queue with one END marker each so they do not wait out the timeout.
        self._stop_event.set()
        live_collectors = [t for t in self._collector_threads if t.is_alive()]
        send_queue = self.queues_and_events['send_audio_chunks_queue']
        for _ in live_collectors:
            send_queue.put(b"END")

        for collector in live_collectors:
            collector.join(timeout=2 * self._COLLECTOR_POLL_SECONDS)
            if collector.is_alive():
                logger.warning("Output collector thread %s did not stop in time", collector.name)

        self._collector_threads = []