DSatishchandra committed
Commit 6892821 · verified · 1 Parent(s): b95d9c0

Update app.py

Files changed (1):
  1. app.py  +202 -59
app.py CHANGED
@@ -1,61 +1,204 @@
  import gradio as gr
- import speech_recognition as sr
- import pyttsx3
- from transformers import pipeline
-
- # Initialize the text-to-speech engine
- engine = pyttsx3.init()
-
- # Initialize the transformer pipeline for NLP (Text Classification or any specific task)
- nlp = pipeline("zero-shot-classification")
-
- # Function to convert speech to text
- def speech_to_text(audio_file):
-     recognizer = sr.Recognizer()
-     with sr.AudioFile(audio_file.name) as source:
-         audio = recognizer.record(source)
-     try:
-         text = recognizer.recognize_google(audio)
-         return text
-     except sr.UnknownValueError:
-         return "Sorry, I didn't catch that."
-     except sr.RequestError:
-         return "Sorry, there's an issue with the speech recognition service."
-
- # Function to process text (handle menu ordering)
- def process_order(text):
-     # You can add your logic here for handling various food orders and preferences
-     result = nlp(text, candidate_labels=["Vegan", "Halal", "Guilt-Free", "Regular"])
-     category = result['labels'][0]
-
-     if "Vegan" in category:
-         response = "You've chosen a Vegan dish."
-     elif "Halal" in category:
-         response = "You've chosen a Halal dish."
-     elif "Guilt-Free" in category:
-         response = "You've chosen a Guilt-Free dish."
-     else:
-         response = "You've chosen a regular dish."
-
-     return response
-
- # Function for Text-to-Speech (Response back to user)
- def speak_response(text):
-     engine.say(text)
-     engine.runAndWait()
-
- # Create Gradio interface
- def voice_assistant(audio_file):
-     text = speech_to_text(audio_file)
-     response = process_order(text)
-     speak_response(response)
-     return response
-
- iface = gr.Interface(fn=voice_assistant,
-                      inputs=gr.inputs.Audio(source="microphone", type="file"),
-                      outputs="text",
-                      live=True)
-
- # Launch Gradio app
  if __name__ == "__main__":
-     iface.launch()
 
 
  import gradio as gr
+ from gradio_webrtc import WebRTC, StreamHandler, get_twilio_turn_credentials
+ import websockets.sync.client
+ import numpy as np
+ import json
+ import base64
+ import os
+ from dotenv import load_dotenv
+
+ class GeminiConfig:
+     def __init__(self):
+         load_dotenv()
+         self.api_key = self._get_api_key()
+         self.host = 'generativelanguage.googleapis.com'
+         self.model = 'models/gemini-2.0-flash-exp'
+         self.ws_url = f'wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}'
+
+     def _get_api_key(self):
+         api_key = os.getenv('GOOGLE_API_KEY')
+         if not api_key:
+             raise ValueError("GOOGLE_API_KEY not found in environment variables. Please set it in your .env file.")
+         return api_key
+
+ class AudioProcessor:
+     @staticmethod
+     def encode_audio(data, sample_rate):
+         encoded = base64.b64encode(data.tobytes()).decode('UTF-8')
+         return {
+             'realtimeInput': {
+                 'mediaChunks': [{
+                     'mimeType': f'audio/pcm;rate={sample_rate}',
+                     'data': encoded,
+                 }],
+             },
+         }
+
+     @staticmethod
+     def process_audio_response(data):
+         audio_data = base64.b64decode(data)
+         return np.frombuffer(audio_data, dtype=np.int16)
+
+ class GeminiHandler(StreamHandler):
+     def __init__(self,
+                  expected_layout="mono",
+                  output_sample_rate=24000,
+                  output_frame_size=480) -> None:
+         super().__init__(expected_layout, output_sample_rate, output_frame_size,
+                          input_sample_rate=24000)
+         self.config = GeminiConfig()
+         self.ws = None
+         self.all_output_data = None
+         self.audio_processor = AudioProcessor()
+
+     def copy(self):
+         return GeminiHandler(
+             expected_layout=self.expected_layout,
+             output_sample_rate=self.output_sample_rate,
+             output_frame_size=self.output_frame_size
+         )
+
+     def _initialize_websocket(self):
+         try:
+             self.ws = websockets.sync.client.connect(
+                 self.config.ws_url,
+                 timeout=30
+             )
+             initial_request = {
+                 'setup': {
+                     'model': self.config.model,
+                 }
+             }
+             self.ws.send(json.dumps(initial_request))
+             setup_response = json.loads(self.ws.recv())
+             print(f"Setup response: {setup_response}")
+         except websockets.exceptions.WebSocketException as e:
+             print(f"WebSocket connection failed: {str(e)}")
+             self.ws = None
+         except Exception as e:
+             print(f"Setup failed: {str(e)}")
+             self.ws = None
+
+     def receive(self, frame: tuple[int, np.ndarray]) -> None:
+         try:
+             if not self.ws:
+                 self._initialize_websocket()
+
+             _, array = frame
+             array = array.squeeze()
+             audio_message = self.audio_processor.encode_audio(array, self.output_sample_rate)
+             self.ws.send(json.dumps(audio_message))
+         except Exception as e:
+             print(f"Error in receive: {str(e)}")
+             if self.ws:
+                 self.ws.close()
+             self.ws = None
+
+     def _process_server_content(self, content):
+         for part in content.get('parts', []):
+             data = part.get('inlineData', {}).get('data', '')
+             if data:
+                 audio_array = self.audio_processor.process_audio_response(data)
+                 if self.all_output_data is None:
+                     self.all_output_data = audio_array
+                 else:
+                     self.all_output_data = np.concatenate((self.all_output_data, audio_array))
+
+                 while self.all_output_data.shape[-1] >= self.output_frame_size:
+                     yield (self.output_sample_rate,
+                            self.all_output_data[:self.output_frame_size].reshape(1, -1))
+                     self.all_output_data = self.all_output_data[self.output_frame_size:]
+
+     def generator(self):
+         while True:
+             if not self.ws:
+                 print("WebSocket not connected")
+                 yield None
+                 continue
+
+             try:
+                 message = self.ws.recv(timeout=5)
+                 msg = json.loads(message)
+
+                 if 'serverContent' in msg:
+                     content = msg['serverContent'].get('modelTurn', {})
+                     yield from self._process_server_content(content)
+             except TimeoutError:
+                 print("Timeout waiting for server response")
+                 yield None
+             except Exception as e:
+                 print(f"Error in generator: {str(e)}")
+                 yield None
+
+     def emit(self) -> tuple[int, np.ndarray] | None:
+         if not self.ws:
+             return None
+         if not hasattr(self, '_generator'):
+             self._generator = self.generator()
+         try:
+             return next(self._generator)
+         except StopIteration:
+             self.reset()
+             return None
+
+     def reset(self) -> None:
+         if hasattr(self, '_generator'):
+             delattr(self, '_generator')
+         self.all_output_data = None
+
+     def shutdown(self) -> None:
+         if self.ws:
+             self.ws.close()
+
+     def check_connection(self):
+         try:
+             if not self.ws or self.ws.closed:
+                 self._initialize_websocket()
+             return True
+         except Exception as e:
+             print(f"Connection check failed: {str(e)}")
+             return False
+
+ class GeminiVoiceChat:
+     def __init__(self):
+         load_dotenv()
+         self.demo = self._create_interface()
+
+     def _create_interface(self):
+         with gr.Blocks() as demo:
+             gr.HTML("""
+                 <div style='text-align: center'>
+                     <h1>Gemini 2.0 Voice Chat</h1>
+                     <p>Speak with Gemini using real-time audio streaming</p>
+                 </div>
+             """)
+
+             webrtc = WebRTC(
+                 label="Conversation",
+                 modality="audio",
+                 mode="send-receive",
+                 rtc_configuration=get_twilio_turn_credentials()
+             )
+
+             webrtc.stream(
+                 GeminiHandler(),
+                 inputs=[webrtc],
+                 outputs=[webrtc],
+                 time_limit=90,
+                 concurrency_limit=10
+             )
+         return demo
+
+     def launch(self):
+         self.demo.launch(
+             server_name="0.0.0.0",
+             server_port=int(os.environ.get("PORT", 7860)),
+             share=True,
+             ssl_verify=False,
+             ssl_keyfile=None,
+             ssl_certfile=None
+         )
+
  if __name__ == "__main__":
+     app = GeminiVoiceChat()
+     app.launch()
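
Note (not part of the commit): a minimal local-run sketch of the updated app. It assumes the file is saved as app.py and that a valid GOOGLE_API_KEY is available to load_dotenv, for example via a .env file; the placeholder key value below is hypothetical.

# Minimal local-run sketch; assumes app.py is importable and GOOGLE_API_KEY is set
# (e.g. exported in the shell or placed in a .env file read by load_dotenv).
import os

os.environ.setdefault("GOOGLE_API_KEY", "replace-with-a-real-key")  # hypothetical placeholder

from app import GeminiVoiceChat

chat = GeminiVoiceChat()  # builds the Gradio Blocks UI with the WebRTC component
chat.launch()             # serves on PORT if set, otherwise 7860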