SpicyqSama007 commited on
Commit
7a0182a
·
verified ·
1 Parent(s): f9b6bed

Streaming agent

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -2,6 +2,8 @@ import re
2
  import gradio as gr
3
  import numpy as np
4
  import os
 
 
5
  import threading
6
  import subprocess
7
  import sys
@@ -52,6 +54,17 @@ class ChatState:
52
  def clear_fn():
53
  return [], ChatState(), None, None, None
54
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  async def process_audio_input(
57
  sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
@@ -72,11 +85,9 @@ async def process_audio_input(
72
 
73
  if isinstance(sys_audio_input, tuple):
74
  sr, sys_audio_data = sys_audio_input
75
- elif text_input:
76
  sr = 44100
77
  sys_audio_data = None
78
- else:
79
- raise gr.Error("Invalid audio format")
80
 
81
  def append_to_chat_ctx(
82
  part: ServeTextPart | ServeVQPart, role: str = "assistant"
@@ -106,22 +117,16 @@ async def process_audio_input(
106
  ):
107
  if event.type == FishE2EEventType.USER_CODES:
108
  append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
 
109
  elif event.type == FishE2EEventType.SPEECH_SEGMENT:
110
- result_audio += event.frame.data
111
- np_audio = np.frombuffer(result_audio, dtype=np.int16)
112
  append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
113
-
114
- yield state.get_history(), (44100, np_audio), None, None
115
  elif event.type == FishE2EEventType.TEXT_SEGMENT:
116
  append_to_chat_ctx(ServeTextPart(text=event.text))
117
- if result_audio:
118
- np_audio = np.frombuffer(result_audio, dtype=np.int16)
119
- yield state.get_history(), (44100, np_audio), None, None
120
- else:
121
- yield state.get_history(), None, None, None
122
-
123
- np_audio = np.frombuffer(result_audio, dtype=np.int16)
124
- yield state.get_history(), (44100, np_audio), None, None
125
 
126
 
127
  async def process_text_input(
 
2
  import gradio as gr
3
  import numpy as np
4
  import os
5
+ import io
6
+ import wave
7
  import threading
8
  import subprocess
9
  import sys
 
54
  def clear_fn():
55
  return [], ChatState(), None, None, None
56
 
57
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
58
+ buffer = io.BytesIO()
59
+
60
+ with wave.open(buffer, "wb") as wav_file:
61
+ wav_file.setnchannels(channels)
62
+ wav_file.setsampwidth(bit_depth // 8)
63
+ wav_file.setframerate(sample_rate)
64
+
65
+ wav_header_bytes = buffer.getvalue()
66
+ buffer.close()
67
+ return wav_header_bytes
68
 
69
  async def process_audio_input(
70
  sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
 
85
 
86
  if isinstance(sys_audio_input, tuple):
87
  sr, sys_audio_data = sys_audio_input
88
+ else:
89
  sr = 44100
90
  sys_audio_data = None
 
 
91
 
92
  def append_to_chat_ctx(
93
  part: ServeTextPart | ServeVQPart, role: str = "assistant"
 
117
  ):
118
  if event.type == FishE2EEventType.USER_CODES:
119
  append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
120
+
121
  elif event.type == FishE2EEventType.SPEECH_SEGMENT:
 
 
122
  append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
123
+ yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
124
+
125
  elif event.type == FishE2EEventType.TEXT_SEGMENT:
126
  append_to_chat_ctx(ServeTextPart(text=event.text))
127
+ yield state.get_history(), None, None, None
128
+
129
+ yield state.get_history(), None, None, None
 
 
 
 
 
130
 
131
 
132
  async def process_text_input(