# flash-attn must be compiled against the torch build already present in the
# endpoint image, so it is installed at startup with --no-build-isolation.
import subprocess
subprocess.run("pip install flash-attn --no-build-isolation", shell=True, check=True)

from typing import Dict, Any

import base64
import threading
import uuid
from queue import Empty

import numpy as np

from s2s_pipeline import (
    build_pipeline,
    get_default_arguments,
    initialize_queues_and_events,
    prepare_all_args,
    setup_logger,
)
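
# Request protocol implemented by EndpointHandler.__call__ (a summary of the
# handlers below, not an additional contract):
#   {"request_type": "start", "input_type": "text" | "speech", "inputs": ...}
#       -> {"session_id": "<uuid>", "status": "new"}
#   {"request_type": "continue", "session_id": "<uuid>", "inputs": <optional base64 audio>}
#       -> {"session_id": "<uuid>", "status": "new" | "processing" | "completed",
#           "output": <base64-encoded audio or None>}
# Speech inputs are base64-encoded 16 kHz signed 16-bit PCM; outputs are the
# pipeline's raw output chunk bytes, base64-encoded.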

class EndpointHandler:
    def __init__(self, path=""):
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = get_default_arguments(mode='none', log_level='DEBUG', lm_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct')
        setup_logger(self.module_kwargs.log_level)

        prepare_all_args(
            self.module_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        )

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )

        self.pipeline_manager.start()

        # Per-session state: status, collected output chunks, and send cursor
        self.sessions = {}
        self.vad_chunk_size = 512  # Chunk size in samples required by the VAD model
        self.sample_rate = 16000  # Expected input sample rate in Hz
        self._shutdown = threading.Event()  # Lets cleanup() stop collector threads

    def _process_audio_chunk(self, audio_data: bytes, session_id: str):
        # session_id is accepted for API symmetry but unused here: all chunks
        # feed the single shared pipeline input queue.
        audio_array = np.frombuffer(audio_data, dtype=np.int16)

        # Split the audio into VAD-sized chunks (measured in samples, not bytes)
        chunks = [audio_array[i:i + self.vad_chunk_size] for i in range(0, len(audio_array), self.vad_chunk_size)]

        for chunk in chunks:
            if len(chunk) < self.vad_chunk_size:
                # Zero-pad the final chunk so the VAD model always sees full frames
                chunk = np.pad(chunk, (0, self.vad_chunk_size - len(chunk)), 'constant')
            self.queues_and_events['recv_audio_chunks_queue'].put(chunk.tobytes())

    def _collect_output(self, session_id):
        # Drain the pipeline's output queue into this session's chunk list until
        # the pipeline emits its END sentinel or the handler shuts down.
        while not self._shutdown.is_set():
            try:
                output = self.queues_and_events['send_audio_chunks_queue'].get(timeout=2)
                if isinstance(output, (str, bytes)) and output in (b"END", "END"):
                    self.sessions[session_id]['status'] = 'completed'
                    break
                elif isinstance(output, np.ndarray):
                    self.sessions[session_id]['chunks'].append(output.tobytes())
                else:
                    self.sessions[session_id]['chunks'].append(output)
            except Empty:
                continue

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        request_type = data.get("request_type", "start")
        
        if request_type == "start":
            return self._handle_start_request(data)
        elif request_type == "continue":
            return self._handle_continue_request(data)
        else:
            raise ValueError(f"Unsupported request type: {request_type}")

    def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        session_id = str(uuid.uuid4())
        self.sessions[session_id] = {
            'status': 'new',
            'chunks': [],
            'last_sent_index': 0,
            'buffer': b''  # Add a buffer to store incomplete chunks
        }

        input_type = data.get("input_type", "text")
        input_data = data.get("inputs", "")

        if input_type == "speech":
            audio_bytes = base64.b64decode(input_data)
            self._process_audio_chunk(audio_bytes, session_id)
        elif input_type == "text":
            self.queues_and_events['text_prompt_queue'].put(input_data)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        # Start output collection in a separate thread; keep a handle so
        # cleanup() can join it later
        collector_thread = threading.Thread(target=self._collect_output, args=(session_id,))
        collector_thread.start()
        self.sessions[session_id]['collector_thread'] = collector_thread

        return {"session_id": session_id, "status": "new"}

    def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
        session_id = data.get("session_id")
        if not session_id or session_id not in self.sessions:
            raise ValueError("Invalid or missing session_id")

        session = self.sessions[session_id]

        if not self.queues_and_events['should_listen'].is_set():
            session['status'] = 'processing'
        elif "inputs" in data:  # Handle additional input if provided  
            input_data = data["inputs"]
            audio_bytes = base64.b64decode(input_data)
            self._process_audio_chunk(audio_bytes, session_id)

        chunks_to_send = session['chunks'][session['last_sent_index']:]
        session['last_sent_index'] = len(session['chunks'])

        if chunks_to_send:
            combined_audio = b''.join(chunks_to_send)
            base64_audio = base64.b64encode(combined_audio).decode('utf-8')
            return {
                "session_id": session_id,
                "status": session['status'],
                "output": base64_audio
            }
        else:
            return {
                "session_id": session_id,
                "status": session['status'],
                "output": None
            }

    def cleanup(self):
        # Stop the pipeline
        self.pipeline_manager.stop()

        # Signal shutdown, unblock the output queue, and join collector threads
        self._shutdown.set()
        self.queues_and_events['send_audio_chunks_queue'].put(b"END")
        for session in self.sessions.values():
            thread = session.get('collector_thread')
            if thread is not None and thread.is_alive():
                thread.join()
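

if __name__ == "__main__":
    # Minimal local smoke test, a sketch rather than part of the endpoint
    # contract: Inference Endpoints invoke EndpointHandler directly, so this
    # block only runs when the file is executed by hand. It assumes the
    # pipeline can be driven end to end from a text prompt and that the output
    # chunks are raw PCM bytes.
    import time

    handler = EndpointHandler()
    response = handler({"request_type": "start", "input_type": "text", "inputs": "Hello!"})
    session_id = response["session_id"]

    audio = b""
    # Poll with "continue" requests until the collector marks the session completed
    while True:
        response = handler({"request_type": "continue", "session_id": session_id})
        if response["output"] is not None:
            audio += base64.b64decode(response["output"])
        if response["status"] == "completed":
            break
        time.sleep(0.5)

    print(f"Collected {len(audio)} bytes of audio")
    handler.cleanup()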