import threading
from queue import Queue
import sounddevice as sd
import numpy as np
import requests
import base64
import time
from dataclasses import dataclass, field

@dataclass
class AudioStreamingClientArguments:
    sample_rate: int = field(default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."})
    chunk_size: int = field(default=1024, metadata={"help": "The size of audio chunks in samples. Default is 1024."})
    api_url: str = field(default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud", metadata={"help": "The URL of the API endpoint."})
    auth_token: str = field(default="your_auth_token", metadata={"help": "Authentication token for the API."})

class AudioStreamingClient:
    def __init__(self, args: AudioStreamingClientArguments):
        self.args = args
        self.stop_event = threading.Event()
        self.send_queue = Queue()
        self.recv_queue = Queue()
        self.session_id = None
        self.headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.args.auth_token}",
            "Content-Type": "application/json"
        }

    def start(self):
        print("Starting audio streaming...")
        
        send_thread = threading.Thread(target=self.send_audio)
        recv_thread = threading.Thread(target=self.receive_audio)
        play_thread = threading.Thread(target=self.play_audio)

        with sd.InputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_callback):
            send_thread.start()
            recv_thread.start()
            play_thread.start()

            try:
                input("Press Enter to stop streaming...")
            except KeyboardInterrupt:
                print("\nStreaming interrupted by user.")
            finally:
                self.stop_event.set()
                send_thread.join()
                recv_thread.join()
                play_thread.join()
                print("Audio streaming stopped.")

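    def audio_callback(self, indata, frames, time_info, status):
        # Runs on the audio thread; copy the buffer because sounddevice reuses it after the callback returns.
        if status:
            print(f"Input stream status: {status}")
        self.send_queue.put(indata.copy())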

    def send_audio(self):
        buffer = b''
        while not self.stop_event.is_set():
            if not self.send_queue.empty():
                chunk = self.send_queue.get().tobytes()
                buffer += chunk
                if len(buffer) >= self.args.chunk_size * 2:  # * 2 because of int16
                    self.send_request(buffer)
                    buffer = b''
            else:
                time.sleep(0.01)

    def send_request(self, audio_data):
        if not self.session_id:
            payload = {
                "request_type": "start",
                "inputs": base64.b64encode(audio_data).decode('utf-8'),
                "input_type": "speech",
            }
        else:
            payload = {
                "request_type": "continue",
                "session_id": self.session_id,
                "inputs": base64.b64encode(audio_data).decode('utf-8'),
            }

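        try:
            # A timeout (value chosen arbitrarily) keeps a stalled request from blocking shutdown.
            response = requests.post(self.args.api_url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
            response_data = response.json()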

            if "session_id" in response_data:
                self.session_id = response_data["session_id"]

            if "output" in response_data and response_data["output"]:
                audio_bytes = base64.b64decode(response_data["output"])
                audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
                self.recv_queue.put(audio_np)

        except Exception as e:
            print(f"Error sending request: {e}")

    def receive_audio(self):
        while not self.stop_event.is_set():
            if self.session_id:
                payload = {
                    "request_type": "continue",
                    "session_id": self.session_id
                }
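                try:
                    # A timeout (value chosen arbitrarily) keeps a stalled request from blocking shutdown.
                    response = requests.post(self.args.api_url, headers=self.headers, json=payload, timeout=30)
                    response_data = response.json()
                    # Use .get() so a response without these keys does not raise KeyError on every poll.
                    output = response_data.get("output")

                    # Stop polling once the server reports completion and has no audio left.
                    if response_data.get("status") == "completed" and not output:
                        break

                    if output:
                        audio_bytes = base64.b64decode(output)
                        audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
                        self.recv_queue.put(audio_np)

                except Exception as e:
                    print(f"Error receiving audio: {e}")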

            time.sleep(0.1)

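    def play_audio(self):
        # Samples received but not yet played; a response is usually longer than one output block.
        pending = np.zeros(0, dtype=np.int16)

        def audio_callback(outdata, frames, time_info, status):
            nonlocal pending
            # Pull queued chunks until there is enough audio to fill this block.
            while len(pending) < frames and not self.recv_queue.empty():
                pending = np.concatenate((pending, self.recv_queue.get()))
            n = min(len(pending), frames)
            outdata[:n, 0] = pending[:n]
            outdata[n:] = 0  # pad with silence when audio runs short
            pending = pending[n:]  # keep the remainder instead of dropping it

        # dtype must match the int16 samples decoded from the API; sounddevice defaults to float32.
        with sd.OutputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=audio_callback):
            while not self.stop_event.is_set():
                time.sleep(0.1)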

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Audio Streaming Client")
    parser.add_argument("--sample_rate", type=int, default=16000, help="Audio sample rate in Hz. Default is 16000.")
    parser.add_argument("--chunk_size", type=int, default=1024, help="The size of audio chunks in samples. Default is 1024.")
    parser.add_argument("--api_url", type=str, required=True, help="The URL of the API endpoint.")
    parser.add_argument("--auth_token", type=str, required=True, help="Authentication token for the API.")

    args = parser.parse_args()
    client_args = AudioStreamingClientArguments(**vars(args))
    client = AudioStreamingClient(client_args)
    client.start()