# File size: 4,038 Bytes
# c72e80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
import os
import threading
from queue import Queue
from typing import Any, Dict, Generator, List

import numpy as np
import torch

from s2s_pipeline import (
    main,
    prepare_args,
    parse_arguments,
    setup_logger,
    initialize_queues_and_events,
    build_pipeline,
)

class EndpointHandler:
    """Endpoint adapter around the speech-to-speech pipeline.

    On construction it obtains all handler configurations from
    ``parse_arguments()``, configures logging, builds and starts the pipeline,
    and launches a background thread that forwards everything the pipeline
    places on ``send_audio_chunks_queue`` into ``final_output_queue``.

    Each call enqueues one input (raw speech bytes or a text prompt) and
    returns a generator streaming output chunks from ``final_output_queue``
    until the ``b"END"`` sentinel is seen.
    """

    # Sentinel that terminates both the collector thread and the output stream.
    _END = b"END"

    def __init__(self, path: str = ""):
        # NOTE(review): `path` is unused; every setting comes from
        # parse_arguments() below — confirm that is intended for this endpoint.
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = parse_arguments()

        setup_logger(self.module_kwargs.log_level)

        prepare_args(self.whisper_stt_handler_kwargs, "stt")
        prepare_args(self.paraformer_stt_handler_kwargs, "paraformer_stt")
        prepare_args(self.language_model_handler_kwargs, "lm")
        prepare_args(self.mlx_language_model_handler_kwargs, "mlx_lm")
        prepare_args(self.parler_tts_handler_kwargs, "tts")
        prepare_args(self.melo_tts_handler_kwargs, "melo")
        prepare_args(self.chat_tts_handler_kwargs, "chat_tts")

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )

        self.pipeline_manager.start()

        # Queue from which __call__ streams results back to the caller.
        self.final_output_queue = Queue()

        # Background thread that moves pipeline output into final_output_queue.
        self.output_collector_thread = threading.Thread(target=self._collect_output)
        self.output_collector_thread.start()

    @staticmethod
    def _is_end(item: Any) -> bool:
        """Return True if *item* is the ``b"END"`` sentinel.

        The isinstance check comes first because queue items may not be bytes
        (e.g. numpy arrays); ``array == b"END"`` compares elementwise and the
        resulting array cannot be used directly in an ``if``.
        """
        return isinstance(item, bytes) and item == EndpointHandler._END

    def _collect_output(self):
        """Forward pipeline output into ``final_output_queue`` until END.

        Runs on the collector thread. The END sentinel itself is forwarded
        too, so a consumer blocked in the output stream is released.
        """
        source = self.queues_and_events['send_audio_chunks_queue']
        while True:
            output = source.get()
            if self._is_end(output):
                self.final_output_queue.put(self._END)
                break
            self.final_output_queue.put(output)

    def __call__(self, data: Dict[str, Any]) -> Generator[Dict[str, Any], None, None]:
        """Submit one input to the pipeline and stream the output chunks.

        Args:
            data: Request payload. ``data["input_type"]`` is ``"text"``
                (default) or ``"speech"``. ``data["input"]`` holds the text
                prompt, or a bytes-like speech buffer that is read as int16
                samples.

        Returns:
            A generator yielding ``{"output": chunk}`` for each chunk taken
            from ``final_output_queue`` until the ``b"END"`` sentinel.

        Raises:
            ValueError: If ``input_type`` is neither ``"speech"`` nor
                ``"text"``. Validation and enqueueing happen immediately at
                call time, not lazily on the first ``next()`` of the result.
                ``np.frombuffer`` also raises ValueError for a speech buffer
                whose length is not a multiple of 2 bytes.
        """
        input_type = data.get("input_type", "text")
        input_data = data.get("input", "")

        if input_type == "speech":
            audio_array = np.frombuffer(input_data, dtype=np.int16)
            self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
        elif input_type == "text":
            self.queues_and_events['text_prompt_queue'].put(input_data)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        return self._stream_output()

    def _stream_output(self) -> Generator[Dict[str, Any], None, None]:
        """Yield ``{"output": chunk}`` from ``final_output_queue`` until END."""
        while True:
            chunk = self.final_output_queue.get()
            if self._is_end(chunk):
                return
            yield {"output": chunk}

    def cleanup(self):
        """Stop the pipeline and shut down the output collector thread.

        The collector is signalled and joined even if ``pipeline_manager.stop()``
        raises, so the thread is not left blocked on its queue.
        """
        try:
            self.pipeline_manager.stop()
        finally:
            # Wake the collector if it is blocked on get(). If it already
            # exited on an END from the pipeline, this extra sentinel simply
            # stays in the queue.
            self.queues_and_events['send_audio_chunks_queue'].put(self._END)
            self.output_collector_thread.join()