IliaLarchenko committed on
Commit
6bb887d
·
1 Parent(s): da72dc0

Improved STT logic

Browse files
Files changed (5) hide show
  1. api/audio.py +46 -13
  2. app.py +1 -0
  3. requirements.txt +1 -0
  4. ui/coding.py +34 -31
  5. utils/ui.py +2 -1
api/audio.py CHANGED
@@ -8,6 +8,28 @@ from openai import OpenAI
8
 
9
  from utils.errors import APIError, AudioConversionError
10
  from typing import List, Dict, Optional, Generator, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  class STTManager:
@@ -42,9 +64,7 @@ class STTManager:
42
  raise AudioConversionError(f"Error converting numpy array to audio bytes: {e}")
43
  return buffer.getvalue()
44
 
45
- def process_audio_chunk(
46
- self, audio: Tuple[int, np.ndarray], audio_buffer: np.ndarray, transcript: Dict
47
- ) -> Tuple[Dict, np.ndarray, str]:
48
  """
49
  Process streamed audio data to accumulate and transcribe with overlapping segments.
50
 
@@ -53,15 +73,26 @@ class STTManager:
53
  :param transcript: Current transcript dictionary.
54
  :return: Updated transcript, updated audio buffer, and transcript text.
55
  """
56
- audio_buffer = np.concatenate((audio_buffer, audio[1]))
57
 
58
- if len(audio_buffer) >= self.SAMPLE_RATE * self.CHUNK_LENGTH or len(audio_buffer) % (self.SAMPLE_RATE // 2) != 0:
59
- audio_bytes = self.numpy_audio_to_bytes(audio_buffer[: self.SAMPLE_RATE * self.CHUNK_LENGTH])
60
- audio_buffer = audio_buffer[self.SAMPLE_RATE * self.STEP_LENGTH :]
61
- new_transcript = self.speech_to_text_stream(audio_bytes)
62
- transcript = self.merge_transcript(transcript, new_transcript)
 
 
 
 
 
 
 
63
 
64
- return transcript, audio_buffer, transcript["text"]
 
 
 
 
 
65
 
66
  def speech_to_text_stream(self, audio: bytes) -> List[Dict[str, str]]:
67
  """
@@ -114,19 +145,21 @@ class STTManager:
114
  transcript["text"] = " ".join(transcript["words"])
115
  return transcript
116
 
117
- def speech_to_text_full(self, audio: Tuple[int, np.ndarray]) -> str:
118
  """
119
  Convert speech to text from a full audio segment.
120
 
121
  :param audio: Tuple containing the sample rate and audio data as numpy array.
122
  :return: Transcribed text.
123
  """
124
- audio_bytes = self.numpy_audio_to_bytes(audio[1])
125
  try:
126
  if self.config.stt.type == "OPENAI_API":
127
  data = ("temp.wav", audio_bytes, "audio/wav")
128
  client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
129
- transcription = client.audio.transcriptions.create(model=self.config.stt.name, file=data, response_format="text")
 
 
130
  elif self.config.stt.type == "HF_API":
131
  headers = {"Authorization": "Bearer " + self.config.stt.key}
132
  response = requests.post(self.config.stt.url, headers=headers, data=audio_bytes)
 
8
 
9
  from utils.errors import APIError, AudioConversionError
10
  from typing import List, Dict, Optional, Generator, Tuple
11
+ import webrtcvad
12
+
13
+
14
def detect_voice(
    audio: np.ndarray,
    sample_rate: int = 48000,
    frame_duration: int = 30,
    speech_frame_threshold: int = 6,
) -> bool:
    """
    Decide whether an audio chunk contains speech using WebRTC VAD.

    The chunk is cut into consecutive frames of ``frame_duration`` milliseconds and every full
    frame is classified by ``webrtcvad``; a trailing partial frame is ignored.

    :param audio: Audio samples as numpy array. int16 data is used as is. Floating-point data is
                  assumed to be normalised to [-1.0, 1.0] (TODO confirm against the audio source)
                  and is rescaled to 16-bit PCM. Any other dtype is cast to int16.
    :param sample_rate: Sample rate in Hz handed to the VAD (webrtcvad supports 8000, 16000, 32000, 48000).
    :param frame_duration: Frame length in milliseconds (webrtcvad supports 10, 20 or 30).
    :param speech_frame_threshold: Speech is reported once MORE than this many frames are voiced.
    :return: True if enough voiced frames were found, False otherwise.
    :raises ValueError: If ``sample_rate`` and ``frame_duration`` yield an empty frame.
    """
    # webrtcvad expects raw 16-bit PCM, so make sure the bytes really are int16 samples
    samples = np.asarray(audio)
    if samples.dtype != np.int16:
        if np.issubdtype(samples.dtype, np.floating):
            samples = np.clip(samples, -1.0, 1.0) * 32767
        samples = samples.astype(np.int16)
    audio_bytes = samples.tobytes()

    # Two bytes per 16-bit sample
    frame_bytes = int(sample_rate * frame_duration / 1000) * 2
    if frame_bytes <= 0:
        raise ValueError("sample_rate and frame_duration must give at least one sample per frame")

    # With too few full frames the threshold can never be exceeded, so skip the VAD entirely
    full_frames = len(audio_bytes) // frame_bytes
    if full_frames <= speech_frame_threshold:
        return False

    vad = webrtcvad.Vad()
    vad.set_mode(3)  # Aggressiveness mode: 0 (least aggressive) to 3 (most aggressive)

    voiced_frames = 0
    for start in range(0, full_frames * frame_bytes, frame_bytes):
        if vad.is_speech(audio_bytes[start : start + frame_bytes], sample_rate):
            voiced_frames += 1
            # Stop as soon as the decision is known
            if voiced_frames > speech_frame_threshold:
                return True
    return False
33
 
34
 
35
  class STTManager:
 
64
  raise AudioConversionError(f"Error converting numpy array to audio bytes: {e}")
65
  return buffer.getvalue()
66
 
67
def process_audio_chunk(self, audio: Tuple[int, np.ndarray], audio_buffer: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Accumulate voiced stream chunks and decide when the buffered audio is released for transcription.

    :param audio: Tuple containing the sample rate and the new chunk of audio data as numpy array.
    :param audio_buffer: Audio accumulated so far that has not been released yet.
    :return: Tuple ``(audio_buffer, audio_to_transcribe)``. Either the (possibly extended) buffer is
             kept and the second item is an empty int16 array, or the buffer is released as the
             second item and the first item is reset to an empty int16 array.
    """
    chunk = audio[1]
    nothing = np.array([], dtype=np.int16)

    voiced = detect_voice(chunk)
    # A chunk whose length is not a multiple of 24000 samples is treated as the end of the recording
    is_last_chunk = len(chunk) % 24000 != 0

    # Only chunks with detected voice are added to the buffer
    if voiced:
        audio_buffer = np.concatenate((audio_buffer, chunk))

    # Release once at least one second (at 48000 samples per second) is buffered
    # and the speaker has paused or the recording has ended
    long_enough = len(audio_buffer) / 48000 >= 1.0
    if long_enough and (is_last_chunk or not voiced):
        return nothing, audio_buffer

    return audio_buffer, nothing
89
 
90
def transcribe_audio(self, audio: np.ndarray, text) -> str:
    """
    Transcribe an audio segment and append the result to the text collected so far.

    :param audio: Audio data as numpy array. Segments shorter than 500 samples are not transcribed.
    :param text: Text transcribed so far; it is also passed to the transcription call as context.
                 ``None`` is treated as an empty string.
    :return: Previous text followed by the new transcript, separated by a single space when both
             are non-empty.
    """
    text = text or ""

    # Nothing meaningful to transcribe, keep the text as it is
    if len(audio) < 500:
        return text

    transcript = self.transcribe_numpy_array(audio, context=text)
    # Join only the non-empty parts so the result never starts or ends with a stray space
    return " ".join(part for part in (text, transcript) if part)
96
 
97
  def speech_to_text_stream(self, audio: bytes) -> List[Dict[str, str]]:
98
  """
 
145
  transcript["text"] = " ".join(transcript["words"])
146
  return transcript
147
 
148
+ def transcribe_numpy_array(self, audio: np.ndarray, context: Optional[str] = None) -> str:
149
  """
150
  Convert speech to text from a full audio segment.
151
 
152
  :param audio: Tuple containing the sample rate and audio data as numpy array.
153
  :return: Transcribed text.
154
  """
155
+ audio_bytes = self.numpy_audio_to_bytes(audio)
156
  try:
157
  if self.config.stt.type == "OPENAI_API":
158
  data = ("temp.wav", audio_bytes, "audio/wav")
159
  client = OpenAI(base_url=self.config.stt.url, api_key=self.config.stt.key)
160
+ transcription = client.audio.transcriptions.create(
161
+ model=self.config.stt.name, file=data, response_format="text", prompt=context
162
+ )
163
  elif self.config.stt.type == "HF_API":
164
  headers = {"Authorization": "Bearer " + self.config.stt.key}
165
  response = requests.post(self.config.stt.url, headers=headers, data=audio_bytes)
app.py CHANGED
@@ -35,6 +35,7 @@ def main():
35
  """Main function to initialize services and launch the Gradio interface."""
36
  config, llm, tts, stt = initialize_services()
37
  demo = create_interface(llm, tts, stt, default_audio_params)
 
38
  demo.launch(show_api=False)
39
 
40
 
 
35
  """Main function to initialize services and launch the Gradio interface."""
36
  config, llm, tts, stt = initialize_services()
37
  demo = create_interface(llm, tts, stt, default_audio_params)
38
+ demo.config["dependencies"][0]["show_progress"] = "hidden"
39
  demo.launch(show_api=False)
40
 
41
 
requirements.txt CHANGED
@@ -2,3 +2,4 @@ gradio==4.29.0
2
  openai==1.19.0
3
  python-dotenv==1.0.1
4
  pytest==8.2.0
 
 
2
  openai==1.19.0
3
  python-dotenv==1.0.1
4
  pytest==8.2.0
5
+ webrtcvad==2.0.10
ui/coding.py CHANGED
@@ -3,11 +3,14 @@ import numpy as np
3
  import os
4
 
5
  from itertools import chain
 
6
 
7
  from resources.data import fixed_messages, topic_lists
8
  from utils.ui import add_candidate_message, add_interviewer_message
9
  from typing import List, Dict, Generator, Optional, Tuple
10
  from functools import partial
 
 
11
 
12
 
13
  def send_request(
@@ -15,8 +18,8 @@ def send_request(
15
  previous_code: str,
16
  chat_history: List[Dict[str, str]],
17
  chat_display: List[List[Optional[str]]],
18
- llm,
19
- tts,
20
  silent: Optional[bool] = False,
21
  ) -> Generator[Tuple[List[Dict[str, str]], List[List[Optional[str]]], str, bytes], None, None]:
22
  """
@@ -26,14 +29,19 @@ def send_request(
26
  if silent is None:
27
  silent = os.getenv("SILENT", False)
28
 
 
 
 
 
29
  chat_history = llm.update_chat_history(code, previous_code, chat_history, chat_display)
30
  original_len = len(chat_display)
31
  chat_display.append([None, ""])
32
- chat_history.append({"role": "assistant", "content": ""})
33
 
34
  text_chunks = []
35
  reply = llm.get_text(chat_history)
36
 
 
 
37
  audio_generator = iter(())
38
  has_text_item = True
39
  has_audio_item = not silent
@@ -99,7 +107,7 @@ def change_code_area(interview_type):
99
  )
100
 
101
 
102
- def get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output):
103
  send_request_partial = partial(send_request, llm=llm, tts=tts)
104
 
105
  with gr.Tab("Interview", render=False, elem_id=f"tab") as problem_tab:
@@ -178,20 +186,22 @@ def get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output):
178
  with gr.Column(scale=1):
179
  end_btn = gr.Button("Finish the interview", interactive=False, variant="stop", elem_id=f"end_btn")
180
  chat = gr.Chatbot(label="Chat", show_label=False, show_share_button=False, elem_id=f"chat")
 
 
 
181
  message = gr.Textbox(
182
  label="Message",
183
  show_label=False,
184
- lines=3,
185
- max_lines=3,
186
- interactive=True,
187
  container=False,
188
  elem_id=f"message",
189
  )
190
- send_btn = gr.Button("Send", interactive=False, elem_id=f"send_btn")
191
- audio_input = gr.Audio(interactive=False, **default_audio_params, elem_id=f"audio_input")
192
 
 
193
  audio_buffer = gr.State(np.array([], dtype=np.int16))
194
- transcript = gr.State({"words": [], "not_confirmed": 0, "last_cutoff": 0, "text": ""})
195
 
196
  with gr.Accordion("Feedback", open=True, visible=False) as feedback_acc:
197
  feedback = gr.Markdown(elem_id=f"feedback", line_breaks=True)
@@ -219,8 +229,8 @@ def get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output):
219
  ).success(
220
  fn=llm.init_bot, inputs=[description, interview_type_select], outputs=[chat_history]
221
  ).success(
222
- fn=lambda: (gr.update(visible=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)),
223
- outputs=[solution_acc, end_btn, audio_input, send_btn],
224
  )
225
 
226
  end_btn.click(fn=lambda x: add_candidate_message("Let's stop here.", x), inputs=[chat], outputs=[chat]).success(
@@ -233,9 +243,8 @@ def get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output):
233
  gr.update(interactive=False),
234
  gr.update(open=False),
235
  gr.update(interactive=False),
236
- gr.update(interactive=False),
237
  ),
238
- outputs=[solution_acc, end_btn, problem_acc, audio_input, send_btn],
239
  ).success(
240
  fn=lambda: (gr.update(visible=True)),
241
  outputs=[feedback_acc],
@@ -243,32 +252,26 @@ def get_problem_solving_ui(llm, tts, stt, default_audio_params, audio_output):
243
  fn=llm.end_interview, inputs=[description, chat_history, interview_type_select], outputs=[feedback]
244
  )
245
 
246
- send_btn.click(fn=add_candidate_message, inputs=[message, chat], outputs=[chat]).success(
247
- fn=lambda: None, outputs=[message]
 
 
 
 
 
 
 
 
248
  ).success(
249
  fn=send_request_partial,
250
  inputs=[code, previous_code, chat_history, chat],
251
  outputs=[chat_history, chat, previous_code, audio_output],
252
- # ).success(
253
- # fn=tts.read_last_message, inputs=[chat], outputs=[audio_output]
254
  ).success(
255
  fn=lambda: np.array([], dtype=np.int16), outputs=[audio_buffer]
256
  ).success(
257
- fn=lambda: {"words": [], "not_confirmed": 0, "last_cutoff": 0, "text": ""}, outputs=[transcript]
258
  )
259
 
260
- if stt.streaming:
261
- audio_input.stream(
262
- stt.process_audio_chunk,
263
- inputs=[audio_input, audio_buffer, transcript],
264
- outputs=[transcript, audio_buffer, message],
265
- show_progress="hidden",
266
- )
267
- else:
268
- audio_input.stop_recording(fn=stt.speech_to_text_full, inputs=[audio_input], outputs=[message]).success(
269
- fn=lambda: gr.update(interactive=True), outputs=[send_btn]
270
- ).success(fn=lambda: None, outputs=[audio_input])
271
-
272
  interview_type_select.change(
273
  fn=lambda x: gr.update(choices=topic_lists[x], value=np.random.choice(topic_lists[x])),
274
  inputs=[interview_type_select],
 
3
  import os
4
 
5
  from itertools import chain
6
+ import time
7
 
8
  from resources.data import fixed_messages, topic_lists
9
  from utils.ui import add_candidate_message, add_interviewer_message
10
  from typing import List, Dict, Generator, Optional, Tuple
11
  from functools import partial
12
+ from api.llm import LLMManager
13
+ from api.audio import TTSManager, STTManager
14
 
15
 
16
  def send_request(
 
18
  previous_code: str,
19
  chat_history: List[Dict[str, str]],
20
  chat_display: List[List[Optional[str]]],
21
+ llm: LLMManager,
22
+ tts: Optional[TTSManager],
23
  silent: Optional[bool] = False,
24
  ) -> Generator[Tuple[List[Dict[str, str]], List[List[Optional[str]]], str, bytes], None, None]:
25
  """
 
29
  if silent is None:
30
  silent = os.getenv("SILENT", False)
31
 
32
+ if chat_display[-1][0] is None and code == previous_code:
33
+ yield chat_history, chat_display, code, b""
34
+ return
35
+
36
  chat_history = llm.update_chat_history(code, previous_code, chat_history, chat_display)
37
  original_len = len(chat_display)
38
  chat_display.append([None, ""])
 
39
 
40
  text_chunks = []
41
  reply = llm.get_text(chat_history)
42
 
43
+ chat_history.append({"role": "assistant", "content": ""})
44
+
45
  audio_generator = iter(())
46
  has_text_item = True
47
  has_audio_item = not silent
 
107
  )
108
 
109
 
110
+ def get_problem_solving_ui(llm: LLMManager, tts: TTSManager, stt: STTManager, default_audio_params: Dict, audio_output):
111
  send_request_partial = partial(send_request, llm=llm, tts=tts)
112
 
113
  with gr.Tab("Interview", render=False, elem_id=f"tab") as problem_tab:
 
186
  with gr.Column(scale=1):
187
  end_btn = gr.Button("Finish the interview", interactive=False, variant="stop", elem_id=f"end_btn")
188
  chat = gr.Chatbot(label="Chat", show_label=False, show_share_button=False, elem_id=f"chat")
189
+
190
+ # I need this message box only because chat component is flickering when I am updating it
191
+ # To be improved in the future
192
  message = gr.Textbox(
193
  label="Message",
194
  show_label=False,
195
+ lines=5,
196
+ max_lines=5,
197
+ interactive=False,
198
  container=False,
199
  elem_id=f"message",
200
  )
 
 
201
 
202
+ audio_input = gr.Audio(interactive=False, **default_audio_params, elem_id=f"audio_input")
203
  audio_buffer = gr.State(np.array([], dtype=np.int16))
204
+ audio_to_transcribe = gr.State(np.array([], dtype=np.int16))
205
 
206
  with gr.Accordion("Feedback", open=True, visible=False) as feedback_acc:
207
  feedback = gr.Markdown(elem_id=f"feedback", line_breaks=True)
 
229
  ).success(
230
  fn=llm.init_bot, inputs=[description, interview_type_select], outputs=[chat_history]
231
  ).success(
232
+ fn=lambda: (gr.update(visible=True), gr.update(interactive=True), gr.update(interactive=True)),
233
+ outputs=[solution_acc, end_btn, audio_input],
234
  )
235
 
236
  end_btn.click(fn=lambda x: add_candidate_message("Let's stop here.", x), inputs=[chat], outputs=[chat]).success(
 
243
  gr.update(interactive=False),
244
  gr.update(open=False),
245
  gr.update(interactive=False),
 
246
  ),
247
+ outputs=[solution_acc, end_btn, problem_acc, audio_input],
248
  ).success(
249
  fn=lambda: (gr.update(visible=True)),
250
  outputs=[feedback_acc],
 
252
  fn=llm.end_interview, inputs=[description, chat_history, interview_type_select], outputs=[feedback]
253
  )
254
 
255
+ audio_input.stream(
256
+ stt.process_audio_chunk,
257
+ inputs=[audio_input, audio_buffer],
258
+ outputs=[audio_buffer, audio_to_transcribe],
259
+ show_progress="hidden",
260
+ ).success(fn=stt.transcribe_audio, inputs=[audio_to_transcribe, message], outputs=[message], show_progress="hidden")
261
+
262
+ # TODO: find a way to remove delay
263
+ audio_input.stop_recording(fn=lambda: time.sleep(2)).success(
264
+ fn=add_candidate_message, inputs=[message, chat], outputs=[chat]
265
  ).success(
266
  fn=send_request_partial,
267
  inputs=[code, previous_code, chat_history, chat],
268
  outputs=[chat_history, chat, previous_code, audio_output],
 
 
269
  ).success(
270
  fn=lambda: np.array([], dtype=np.int16), outputs=[audio_buffer]
271
  ).success(
272
+ lambda: "", outputs=[message]
273
  )
274
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  interview_type_select.change(
276
  fn=lambda x: gr.update(choices=topic_lists[x], value=np.random.choice(topic_lists[x])),
277
  inputs=[interview_type_select],
utils/ui.py CHANGED
@@ -8,7 +8,8 @@ def add_interviewer_message(message):
8
 
9
 
10
  def add_candidate_message(message, chat):
11
- chat.append((message, None))
 
12
  return chat
13
 
14
 
 
8
 
9
 
10
  def add_candidate_message(message, chat):
11
+ if message and len(message) > 0:
12
+ chat.append((message, None))
13
  return chat
14
 
15