File size: 6,686 Bytes
0d00307
 
 
 
 
 
 
 
 
 
 
 
9a5a5b3
0d00307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5a5b3
0d00307
 
 
 
 
 
 
9a5a5b3
0d00307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5a5b3
0d00307
 
6383c51
 
 
 
 
 
 
 
 
0d00307
9a5a5b3
 
 
 
8745348
 
9a5a5b3
 
 
 
0d00307
9a5a5b3
 
 
0d00307
9a5a5b3
0d00307
 
 
 
 
 
 
 
9a5a5b3
 
 
 
0d00307
9a5a5b3
 
0d00307
 
9a5a5b3
 
 
 
0d00307
8745348
 
 
 
 
 
 
 
 
0d00307
 
9a5a5b3
0d00307
 
 
 
 
9a5a5b3
 
 
 
 
 
 
0d00307
9a5a5b3
0d00307
 
 
9a5a5b3
0d00307
9a5a5b3
0d00307
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import threading
from queue import Queue
import sounddevice as sd
import numpy as np
import requests
import base64
import time
from dataclasses import dataclass, field

@dataclass
class AudioStreamingClientArguments:
    sample_rate: int = field(default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."})
    chunk_size: int = field(default=512, metadata={"help": "The size of audio chunks in samples. Default is 1024."})
    api_url: str = field(default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud", metadata={"help": "The URL of the API endpoint."})
    auth_token: str = field(default="your_auth_token", metadata={"help": "Authentication token for the API."})

class AudioStreamingClient:
    def __init__(self, args: AudioStreamingClientArguments):
        self.args = args
        self.stop_event = threading.Event()
        self.send_queue = Queue()
        self.recv_queue = Queue()
        self.session_id = None
        self.headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.args.auth_token}",
            "Content-Type": "application/json"
        }
        self.session_state = "idle"  # Possible states: idle, sending, processing, waiting

    def start(self):
        print("Starting audio streaming...")
        
        send_thread = threading.Thread(target=self.send_audio)
        play_thread = threading.Thread(target=self.play_audio)

        with sd.InputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_callback, blocksize=self.args.chunk_size):
            send_thread.start()
            play_thread.start()

            try:
                input("Press Enter to stop streaming...")
            except KeyboardInterrupt:
                print("\nStreaming interrupted by user.")
            finally:
                self.stop_event.set()
                send_thread.join()
                play_thread.join()
                print("Audio streaming stopped.")

    def audio_callback(self, indata, frames, time, status):
        self.send_queue.put(indata.copy())

    def send_audio(self):
        buffer = b''
        while not self.stop_event.is_set():
            if self.session_state != "processing" and not self.send_queue.empty():
                chunk = self.send_queue.get().tobytes()
                buffer += chunk
                
                # Calculate energy of the audio chunk
                energy = np.sum(np.square(np.frombuffer(chunk, dtype=np.int16))) / len(chunk)
                print(f"Energy: {energy}")
                
                if energy > 0.01:  # Threshold for energy detection
                    if len(buffer) >= self.args.chunk_size * 2:  # * 2 because of int16
                        self.send_request(buffer)
                        buffer = b''
            else:
                self.send_request()
                time.sleep(0.1)

    def send_request(self, audio_data=None):
        payload = {"input_type": "speech",
                   "inputs": ""}

        if audio_data is not None:
            print("Sending audio data")
            payload["inputs"] = base64.b64encode(audio_data).decode('utf-8')

        if self.session_id:
            payload["session_id"] = self.session_id
            payload["request_type"] = "continue"
        else:
            payload["request_type"] = "start"

        try:
            response = requests.post(self.args.api_url, headers=self.headers, json=payload)
            response_data = response.json()

            if "session_id" in response_data:
                self.session_id = response_data["session_id"]

            if "status" in response_data and response_data["status"] == "processing":
                print("Processing audio data")
                self.session_state = "processing"

            if "output" in response_data and response_data["output"]:
                print("Received audio data")
                self.session_state = "processing"  # Set state to processing when we start receiving audio
                audio_bytes = base64.b64decode(response_data["output"])
                audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
                # Split the audio into smaller chunks for playback
                for i in range(0, len(audio_np), self.args.chunk_size):
                    chunk = audio_np[i:i+self.args.chunk_size]
                    self.recv_queue.put(chunk)

            if "status" in response_data and response_data["status"] == "completed":
                print("Completed audio processing")
                self.session_state = None
                self.session_id = None
                while not self.recv_queue.empty():
                    time.sleep(0.01)  # wait for the queue to empty
                while not self.send_queue.empty():
                    _ = self.send_queue.get()  # Clear the queue

        except Exception as e:
            print(f"Error sending request: {e}")
            self.session_state = "idle"  # Reset state to idle in case of error

    def play_audio(self):
        def audio_callback(outdata, frames, time, status):
            if not self.recv_queue.empty():
                chunk = self.recv_queue.get()
                
                # Ensure chunk is int16 and clip to valid range
                chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16)
                
                if len(chunk_int16) < len(outdata):
                    outdata[:len(chunk_int16), 0] = chunk_int16
                    outdata[len(chunk_int16):] = 0
                else:
                    outdata[:, 0] = chunk_int16[:len(outdata)]
            else:
                outdata[:] = 0

        with sd.OutputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=audio_callback, blocksize=self.args.chunk_size):
            while not self.stop_event.is_set():
                time.sleep(0.01)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Audio Streaming Client")
    parser.add_argument("--sample_rate", type=int, default=16000, help="Audio sample rate in Hz. Default is 16000.")
    parser.add_argument("--chunk_size", type=int, default=1024, help="The size of audio chunks in samples. Default is 1024.")
    parser.add_argument("--api_url", type=str, required=True, help="The URL of the API endpoint.")
    parser.add_argument("--auth_token", type=str, required=True, help="Authentication token for the API.")

    args = parser.parse_args()
    client_args = AudioStreamingClientArguments(**vars(args))
    client = AudioStreamingClient(client_args)
    client.start()