andito HF staff commited on
Commit
9a5a5b3
1 Parent(s): 0d00307

Upload folder using huggingface_hub

Browse files
audio_streaming_client.py CHANGED
@@ -10,7 +10,7 @@ from dataclasses import dataclass, field
10
  @dataclass
11
  class AudioStreamingClientArguments:
12
  sample_rate: int = field(default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."})
13
- chunk_size: int = field(default=1024, metadata={"help": "The size of audio chunks in samples. Default is 1024."})
14
  api_url: str = field(default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud", metadata={"help": "The URL of the API endpoint."})
15
  auth_token: str = field(default="your_auth_token", metadata={"help": "Authentication token for the API."})
16
 
@@ -26,17 +26,16 @@ class AudioStreamingClient:
26
  "Authorization": f"Bearer {self.args.auth_token}",
27
  "Content-Type": "application/json"
28
  }
 
29
 
30
  def start(self):
31
  print("Starting audio streaming...")
32
 
33
  send_thread = threading.Thread(target=self.send_audio)
34
- recv_thread = threading.Thread(target=self.receive_audio)
35
  play_thread = threading.Thread(target=self.play_audio)
36
 
37
- with sd.InputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_callback):
38
  send_thread.start()
39
- recv_thread.start()
40
  play_thread.start()
41
 
42
  try:
@@ -46,7 +45,6 @@ class AudioStreamingClient:
46
  finally:
47
  self.stop_event.set()
48
  send_thread.join()
49
- recv_thread.join()
50
  play_thread.join()
51
  print("Audio streaming stopped.")
52
 
@@ -56,28 +54,29 @@ class AudioStreamingClient:
56
  def send_audio(self):
57
  buffer = b''
58
  while not self.stop_event.is_set():
59
- if not self.send_queue.empty():
60
  chunk = self.send_queue.get().tobytes()
61
  buffer += chunk
62
  if len(buffer) >= self.args.chunk_size * 2: # * 2 because of int16
63
  self.send_request(buffer)
64
  buffer = b''
65
  else:
66
- time.sleep(0.01)
 
 
 
 
 
 
 
 
 
67
 
68
- def send_request(self, audio_data):
69
- if not self.session_id:
70
- payload = {
71
- "request_type": "start",
72
- "inputs": base64.b64encode(audio_data).decode('utf-8'),
73
- "input_type": "speech",
74
- }
75
  else:
76
- payload = {
77
- "request_type": "continue",
78
- "session_id": self.session_id,
79
- "inputs": base64.b64encode(audio_data).decode('utf-8'),
80
- }
81
 
82
  try:
83
  response = requests.post(self.args.api_url, headers=self.headers, json=payload)
@@ -86,53 +85,48 @@ class AudioStreamingClient:
86
  if "session_id" in response_data:
87
  self.session_id = response_data["session_id"]
88
 
 
 
 
 
 
 
 
 
 
89
  if "output" in response_data and response_data["output"]:
 
 
90
  audio_bytes = base64.b64decode(response_data["output"])
91
  audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
92
- self.recv_queue.put(audio_np)
 
 
 
93
 
94
  except Exception as e:
95
  print(f"Error sending request: {e}")
96
-
97
- def receive_audio(self):
98
- while not self.stop_event.is_set():
99
- if self.session_id:
100
- payload = {
101
- "request_type": "continue",
102
- "session_id": self.session_id
103
- }
104
- try:
105
- response = requests.post(self.args.api_url, headers=self.headers, json=payload)
106
- response_data = response.json()
107
-
108
- if response_data["status"] == "completed" and not response_data["output"]:
109
- break
110
-
111
- if response_data["output"]:
112
- audio_bytes = base64.b64decode(response_data["output"])
113
- audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
114
- self.recv_queue.put(audio_np)
115
-
116
- except Exception as e:
117
- print(f"Error receiving audio: {e}")
118
-
119
- time.sleep(0.1)
120
 
121
  def play_audio(self):
122
  def audio_callback(outdata, frames, time, status):
123
  if not self.recv_queue.empty():
124
  chunk = self.recv_queue.get()
125
- if len(chunk) < len(outdata):
126
- outdata[:len(chunk)] = chunk.reshape(-1, 1)
127
- outdata[len(chunk):] = 0
 
 
 
 
128
  else:
129
- outdata[:] = chunk[:len(outdata)].reshape(-1, 1)
130
  else:
131
  outdata[:] = 0
132
 
133
- with sd.OutputStream(samplerate=self.args.sample_rate, channels=1, callback=audio_callback):
134
  while not self.stop_event.is_set():
135
- time.sleep(0.1)
136
 
137
  if __name__ == "__main__":
138
  import argparse
 
10
  @dataclass
11
  class AudioStreamingClientArguments:
12
  sample_rate: int = field(default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."})
13
+ chunk_size: int = field(default=512, metadata={"help": "The size of audio chunks in samples. Default is 1024."})
14
  api_url: str = field(default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud", metadata={"help": "The URL of the API endpoint."})
15
  auth_token: str = field(default="your_auth_token", metadata={"help": "Authentication token for the API."})
16
 
 
26
  "Authorization": f"Bearer {self.args.auth_token}",
27
  "Content-Type": "application/json"
28
  }
29
+ self.session_state = "idle" # Possible states: idle, sending, processing, waiting
30
 
31
  def start(self):
32
  print("Starting audio streaming...")
33
 
34
  send_thread = threading.Thread(target=self.send_audio)
 
35
  play_thread = threading.Thread(target=self.play_audio)
36
 
37
+ with sd.InputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_callback, blocksize=self.args.chunk_size):
38
  send_thread.start()
 
39
  play_thread.start()
40
 
41
  try:
 
45
  finally:
46
  self.stop_event.set()
47
  send_thread.join()
 
48
  play_thread.join()
49
  print("Audio streaming stopped.")
50
 
 
54
  def send_audio(self):
55
  buffer = b''
56
  while not self.stop_event.is_set():
57
+ if self.session_state != "processing" and not self.send_queue.empty():
58
  chunk = self.send_queue.get().tobytes()
59
  buffer += chunk
60
  if len(buffer) >= self.args.chunk_size * 2: # * 2 because of int16
61
  self.send_request(buffer)
62
  buffer = b''
63
  else:
64
+ self.send_request()
65
+ time.sleep(0.1)
66
+
67
+ def send_request(self, audio_data=None):
68
+ payload = {}
69
+
70
+ if audio_data is not None:
71
+ print("Sending audio data")
72
+ payload["inputs"] = base64.b64encode(audio_data).decode('utf-8')
73
+ payload["input_type"] = "speech"
74
 
75
+ if self.session_id:
76
+ payload["session_id"] = self.session_id
77
+ payload["request_type"] = "continue"
 
 
 
 
78
  else:
79
+ payload["request_type"] = "start"
 
 
 
 
80
 
81
  try:
82
  response = requests.post(self.args.api_url, headers=self.headers, json=payload)
 
85
  if "session_id" in response_data:
86
  self.session_id = response_data["session_id"]
87
 
88
+ if "status" in response_data and response_data["status"] == "processing":
89
+ print("Processing audio data")
90
+ self.session_state = "processing"
91
+ elif "status" in response_data and response_data["status"] == "completed":
92
+ print("Completed audio processing")
93
+ self.session_state = None
94
+ self.session_id = None
95
+ _ = self.send_queue.get() # Clear the queue
96
+
97
  if "output" in response_data and response_data["output"]:
98
+ print("Received audio data")
99
+ self.session_state = "processing" # Set state to processing when we start receiving audio
100
  audio_bytes = base64.b64decode(response_data["output"])
101
  audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
102
+ # Split the audio into smaller chunks for playback
103
+ for i in range(0, len(audio_np), self.args.chunk_size):
104
+ chunk = audio_np[i:i+self.args.chunk_size]
105
+ self.recv_queue.put(chunk)
106
 
107
  except Exception as e:
108
  print(f"Error sending request: {e}")
109
+ self.session_state = "idle" # Reset state to idle in case of error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def play_audio(self):
112
  def audio_callback(outdata, frames, time, status):
113
  if not self.recv_queue.empty():
114
  chunk = self.recv_queue.get()
115
+
116
+ # Ensure chunk is int16 and clip to valid range
117
+ chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16)
118
+
119
+ if len(chunk_int16) < len(outdata):
120
+ outdata[:len(chunk_int16), 0] = chunk_int16
121
+ outdata[len(chunk_int16):] = 0
122
  else:
123
+ outdata[:, 0] = chunk_int16[:len(outdata)]
124
  else:
125
  outdata[:] = 0
126
 
127
+ with sd.OutputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=audio_callback, blocksize=self.args.chunk_size):
128
  while not self.stop_event.is_set():
129
+ time.sleep(0.01)
130
 
131
  if __name__ == "__main__":
132
  import argparse
audio_streaming_test.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from queue import Queue
3
+ import sounddevice as sd
4
+ import numpy as np
5
+ import requests
6
+ import base64
7
+ import time
8
+ from dataclasses import dataclass, field
9
+
10
+ @dataclass
11
+ class AudioStreamingClientArguments:
12
+ sample_rate: int = field(default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."})
13
+ chunk_size: int = field(default=512, metadata={"help": "The size of audio chunks in samples. Default is 1024."})
14
+ api_url: str = field(default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud", metadata={"help": "The URL of the API endpoint."})
15
+ auth_token: str = field(default="your_auth_token", metadata={"help": "Authentication token for the API."})
16
+
17
+ class AudioStreamingClient:
18
+ def __init__(self, args: AudioStreamingClientArguments, handler):
19
+ self.args = args
20
+ self.handler = handler
21
+ self.stop_event = threading.Event()
22
+ self.send_queue = Queue()
23
+ self.recv_queue = Queue()
24
+ self.session_id = None
25
+ self.headers = {
26
+ "Accept": "application/json",
27
+ "Authorization": f"Bearer {self.args.auth_token}",
28
+ "Content-Type": "application/json"
29
+ }
30
+ self.session_state = "idle" # Possible states: idle, sending, processing, waiting
31
+
32
+ def start(self):
33
+ print("Starting audio streaming...")
34
+
35
+ send_thread = threading.Thread(target=self.send_audio)
36
+ play_thread = threading.Thread(target=self.play_audio)
37
+
38
+ with sd.InputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_callback, blocksize=self.args.chunk_size):
39
+ send_thread.start()
40
+ play_thread.start()
41
+
42
+ try:
43
+ input("Press Enter to stop streaming...")
44
+ except KeyboardInterrupt:
45
+ print("\nStreaming interrupted by user.")
46
+ finally:
47
+ self.stop_event.set()
48
+ send_thread.join()
49
+ play_thread.join()
50
+ print("Audio streaming stopped.")
51
+
52
+ def audio_callback(self, indata, frames, time, status):
53
+ self.send_queue.put(indata.copy())
54
+
55
+ def send_audio(self):
56
+ buffer = b''
57
+ while not self.stop_event.is_set():
58
+ if self.session_state != "processing" and not self.send_queue.empty():
59
+ chunk = self.send_queue.get().tobytes()
60
+ buffer += chunk
61
+ if len(buffer) >= self.args.chunk_size * 2: # * 2 because of int16
62
+ self.send_request(buffer)
63
+ buffer = b''
64
+ else:
65
+ self.send_request()
66
+ time.sleep(0.1)
67
+
68
+ def send_request(self, audio_data=None):
69
+ payload = {}
70
+
71
+ if audio_data is not None:
72
+ print("Sending audio data")
73
+ payload["inputs"] = base64.b64encode(audio_data).decode('utf-8')
74
+ payload["input_type"] = "speech"
75
+
76
+ if self.session_id:
77
+ payload["session_id"] = self.session_id
78
+ payload["request_type"] = "continue"
79
+ else:
80
+ payload["request_type"] = "start"
81
+
82
+ try:
83
+ response_data = self.handler(payload)
84
+
85
+ if "session_id" in response_data:
86
+ self.session_id = response_data["session_id"]
87
+
88
+ if "status" in response_data and response_data["status"] == "processing":
89
+ print("Processing audio data")
90
+ self.session_state = "processing"
91
+ elif "status" in response_data and response_data["status"] == "completed":
92
+ print("Completed audio processing")
93
+ self.session_state = None
94
+ self.session_id = None
95
+ _ = self.send_queue.get() # Clear the queue
96
+
97
+ if "output" in response_data and response_data["output"]:
98
+ print("Received audio data")
99
+ self.session_state = "processing" # Set state to processing when we start receiving audio
100
+ audio_bytes = base64.b64decode(response_data["output"])
101
+ audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
102
+ # Split the audio into smaller chunks for playback
103
+ for i in range(0, len(audio_np), self.args.chunk_size):
104
+ chunk = audio_np[i:i+self.args.chunk_size]
105
+ self.recv_queue.put(chunk)
106
+
107
+ except Exception as e:
108
+ print(f"Error sending request: {e}")
109
+ self.session_state = "idle" # Reset state to idle in case of error
110
+
111
+ def play_audio(self):
112
+ def audio_callback(outdata, frames, time, status):
113
+ if not self.recv_queue.empty():
114
+ chunk = self.recv_queue.get()
115
+
116
+ # Ensure chunk is int16 and clip to valid range
117
+ chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16)
118
+
119
+ if len(chunk_int16) < len(outdata):
120
+ outdata[:len(chunk_int16), 0] = chunk_int16
121
+ outdata[len(chunk_int16):] = 0
122
+ else:
123
+ outdata[:, 0] = chunk_int16[:len(outdata)]
124
+ else:
125
+ outdata[:] = 0
126
+
127
+ with sd.OutputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=audio_callback, blocksize=self.args.chunk_size):
128
+ while not self.stop_event.is_set():
129
+ time.sleep(0.01)
130
+
131
+ if __name__ == "__main__":
132
+ import argparse
133
+
134
+ parser = argparse.ArgumentParser(description="Audio Streaming Client")
135
+ parser.add_argument("--sample_rate", type=int, default=16000, help="Audio sample rate in Hz. Default is 16000.")
136
+ parser.add_argument("--chunk_size", type=int, default=1024, help="The size of audio chunks in samples. Default is 1024.")
137
+ parser.add_argument("--api_url", type=str, required=True, help="The URL of the API endpoint.")
138
+ parser.add_argument("--auth_token", type=str, required=True, help="Authentication token for the API.")
139
+
140
+ args = parser.parse_args()
141
+ client_args = AudioStreamingClientArguments(**vars(args))
142
+ client = AudioStreamingClient(client_args)
143
+ client.start()
handler.py CHANGED
@@ -23,7 +23,7 @@ class EndpointHandler:
23
  self.parler_tts_handler_kwargs,
24
  self.melo_tts_handler_kwargs,
25
  self.chat_tts_handler_kwargs,
26
- ) = get_default_arguments(mode='none', log_level='DEBUG')
27
  setup_logger(self.module_kwargs.log_level)
28
 
29
  prepare_all_args(
@@ -59,6 +59,22 @@ class EndpointHandler:
59
  # Add a new queue for collecting the final output
60
  self.final_output_queue = Queue()
61
  self.sessions = {} # Store session information
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def _collect_output(self, session_id):
64
  while True:
@@ -87,9 +103,10 @@ class EndpointHandler:
87
  def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
88
  session_id = str(uuid.uuid4())
89
  self.sessions[session_id] = {
90
- 'status': 'processing',
91
  'chunks': [],
92
- 'last_sent_index': 0
 
93
  }
94
 
95
  input_type = data.get("input_type", "text")
@@ -97,17 +114,16 @@ class EndpointHandler:
97
 
98
  if input_type == "speech":
99
  audio_bytes = base64.b64decode(input_data)
100
- audio_array = np.frombuffer(audio_bytes, dtype=np.int16)
101
- self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
102
  elif input_type == "text":
103
  self.queues_and_events['text_prompt_queue'].put(input_data)
104
- else:
105
  raise ValueError(f"Unsupported input type: {input_type}")
106
 
107
  # Start output collection in a separate thread
108
  threading.Thread(target=self._collect_output, args=(session_id,)).start()
109
 
110
- return {"session_id": session_id, "status": "processing"}
111
 
112
  def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
113
  session_id = data.get("session_id")
@@ -116,12 +132,12 @@ class EndpointHandler:
116
 
117
  session = self.sessions[session_id]
118
 
119
- # Handle additional input if provided
120
- if "inputs" in data:
 
121
  input_data = data["inputs"]
122
  audio_bytes = base64.b64decode(input_data)
123
- audio_array = np.frombuffer(audio_bytes, dtype=np.int16)
124
- self.queues_and_events['recv_audio_chunks_queue'].put(audio_array.tobytes())
125
 
126
  chunks_to_send = session['chunks'][session['last_sent_index']:]
127
  session['last_sent_index'] = len(session['chunks'])
 
23
  self.parler_tts_handler_kwargs,
24
  self.melo_tts_handler_kwargs,
25
  self.chat_tts_handler_kwargs,
26
+ ) = get_default_arguments(mode='none', log_level='DEBUG', stt='whisper-mlx', tts='melo', device='mps')
27
  setup_logger(self.module_kwargs.log_level)
28
 
29
  prepare_all_args(
 
59
  # Add a new queue for collecting the final output
60
  self.final_output_queue = Queue()
61
  self.sessions = {} # Store session information
62
+ self.vad_chunk_size = 512 # Set the chunk size required by the VAD model
63
+ self.sample_rate = 16000 # Set the expected sample rate
64
+
65
+ def _process_audio_chunk(self, audio_data: bytes, session_id: str):
66
+ audio_array = np.frombuffer(audio_data, dtype=np.int16)
67
+
68
+ # Ensure the audio is in chunks of the correct size
69
+ chunks = [audio_array[i:i+self.vad_chunk_size] for i in range(0, len(audio_array), self.vad_chunk_size)]
70
+
71
+ for chunk in chunks:
72
+ if len(chunk) == self.vad_chunk_size:
73
+ self.queues_and_events['recv_audio_chunks_queue'].put(chunk.tobytes())
74
+ elif len(chunk) < self.vad_chunk_size:
75
+ # Pad the last chunk if it's smaller than the required size
76
+ padded_chunk = np.pad(chunk, (0, self.vad_chunk_size - len(chunk)), 'constant')
77
+ self.queues_and_events['recv_audio_chunks_queue'].put(padded_chunk.tobytes())
78
 
79
  def _collect_output(self, session_id):
80
  while True:
 
103
  def _handle_start_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
104
  session_id = str(uuid.uuid4())
105
  self.sessions[session_id] = {
106
+ 'status': 'new',
107
  'chunks': [],
108
+ 'last_sent_index': 0,
109
+ 'buffer': b'' # Add a buffer to store incomplete chunks
110
  }
111
 
112
  input_type = data.get("input_type", "text")
 
114
 
115
  if input_type == "speech":
116
  audio_bytes = base64.b64decode(input_data)
117
+ self._process_audio_chunk(audio_bytes, session_id)
 
118
  elif input_type == "text":
119
  self.queues_and_events['text_prompt_queue'].put(input_data)
120
+ else:
121
  raise ValueError(f"Unsupported input type: {input_type}")
122
 
123
  # Start output collection in a separate thread
124
  threading.Thread(target=self._collect_output, args=(session_id,)).start()
125
 
126
+ return {"session_id": session_id, "status": "new"}
127
 
128
  def _handle_continue_request(self, data: Dict[str, Any]) -> Dict[str, Any]:
129
  session_id = data.get("session_id")
 
132
 
133
  session = self.sessions[session_id]
134
 
135
+ if not self.queues_and_events['should_listen'].is_set():
136
+ session['status'] = 'processing'
137
+ elif "inputs" in data: # Handle additional input if provided
138
  input_data = data["inputs"]
139
  audio_bytes = base64.b64decode(input_data)
140
+ self._process_audio_chunk(audio_bytes, session_id)
 
141
 
142
  chunks_to_send = session['chunks'][session['last_sent_index']:]
143
  session['last_sent_index'] = len(session['chunks'])
test_audio_handler.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+ from audio_streaming_test import AudioStreamingClientArguments, AudioStreamingClient
3
+
4
+ my_handler = EndpointHandler()
5
+
6
+ args = AudioStreamingClientArguments()
7
+
8
+ client = AudioStreamingClient(args, my_handler)
9
+
10
+ client.start()