freddyaboulton HF staff committed on
Commit
405ac70
·
1 Parent(s): b4cd421
Files changed (2) hide show
  1. app.py +167 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import os
4
+ import time
5
+ from io import BytesIO
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ from google import genai
10
+ from gradio_webrtc import (
11
+ AsyncAudioVideoStreamHandler,
12
+ WebRTC,
13
+ async_aggregate_bytes_to_16bit,
14
+ VideoEmitType,
15
+ AudioEmitType
16
+ )
17
+ from PIL import Image
18
+
19
+
20
def encode_audio(data: np.ndarray) -> dict:
    """Wrap raw audio samples in the message format sent to the server.

    Args:
        data: Audio samples; their in-memory bytes are sent as-is.

    Returns:
        A dict with ``mime_type`` set to ``"audio/pcm"`` and ``data`` holding
        the base64 text of the sample bytes.
    """
    raw_pcm = data.tobytes()
    payload = base64.b64encode(raw_pcm).decode("UTF-8")
    return {"mime_type": "audio/pcm", "data": payload}
23
+
24
+
25
def encode_image(data: np.ndarray) -> dict:
    """Compress an image array to JPEG and wrap it as a base64 message.

    Args:
        data: Image pixels in a layout ``PIL.Image.fromarray`` accepts.

    Returns:
        A dict with ``mime_type`` set to ``"image/jpeg"`` and ``data`` holding
        the base64 text of the encoded JPEG.
    """
    buffer = BytesIO()
    try:
        Image.fromarray(data).save(buffer, "JPEG")
        jpeg_bytes = buffer.getvalue()
    finally:
        # Release the in-memory buffer whether or not encoding succeeded.
        buffer.close()
    return {
        "mime_type": "image/jpeg",
        "data": base64.b64encode(jpeg_bytes).decode("utf-8"),
    }
32
+
33
+
34
class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Relay a WebRTC audio/video stream to a Gemini live session and back.

    Incoming audio and (throttled) video frames are forwarded to the session.
    Audio returned by the session is queued and handed back through ``emit``;
    video frames are echoed back unchanged through ``video_emit``.
    """

    def __init__(
        self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
    ) -> None:
        """Set up queues and session state.

        Args:
            expected_layout: Forwarded to the base stream handler.
            output_sample_rate: Sample rate reported with emitted audio.
            output_frame_size: Forwarded to the base stream handler.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.audio_queue = asyncio.Queue()
        self.video_queue = asyncio.Queue()
        self.quit = asyncio.Event()
        self.session = None
        self.last_frame_time = 0
        # Strong references to background tasks: the event loop only keeps
        # weak references, so an unreferenced task may be collected mid-run.
        self._connect_task = None
        self._receive_task = None

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler configured like this one (no shared state)."""
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    async def video_receive(self, frame: np.ndarray):
        """Forward a video frame to the session (at most one per second) and echo it.

        When a frame is forwarded, the optional image input
        (``latest_args[2]``) is sent along with it if one was provided.
        """
        # Work on a local reference: the session attribute is cleared when the
        # connection closes, which can happen while a send below is awaiting.
        session = self.session
        if session:
            # A monotonic clock keeps the throttle immune to wall-clock jumps.
            now = time.monotonic()
            # send image every 1 second
            if now - self.last_frame_time > 1:
                self.last_frame_time = now
                await session.send(encode_image(frame))
                if self.latest_args[2] is not None:
                    await session.send(encode_image(self.latest_args[2]))
        self.video_queue.put_nowait(frame)

    async def video_emit(self) -> VideoEmitType:
        """Return the next queued video frame, waiting until one is available."""
        return await self.video_queue.get()

    async def connect(self, api_key: str):
        """Open the live session and keep it open until ``shutdown`` is called.

        Does nothing if a session is already active.

        Args:
            api_key: API key used to create the Gen AI client.
        """
        if self.session is not None:
            return
        client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
        config = {"response_modalities": ["AUDIO"]}
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            self.session = session
            self._receive_task = asyncio.create_task(self.receive_audio())
            try:
                await self.quit.wait()
            finally:
                # Stop reading before the connection is torn down, and drop the
                # stale session so nothing keeps sending on a closed connection.
                self._receive_task.cancel()
                self.session = None

    async def generator(self):
        """Yield the data payload of each session response until shutdown."""
        while not self.quit.is_set():
            turn = self.session.receive()
            async for response in turn:
                if data := response.data:
                    yield data

    async def receive_audio(self):
        """Aggregate response bytes into 16-bit audio chunks and queue them."""
        async for audio_response in async_aggregate_bytes_to_16bit(
            self.generator()
        ):
            self.audio_queue.put_nowait(audio_response)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one incoming audio frame to the session, if connected.

        Args:
            frame: ``(sample_rate, samples)``; only the samples are sent.
        """
        session = self.session
        if not session:
            # Nothing to send to yet; skip the encoding work entirely.
            return
        _, array = frame
        await session.send(encode_audio(array.squeeze()))

    async def emit(self) -> AudioEmitType:
        """Return the next chunk of response audio, connecting on first use.

        The API key is read from ``latest_args[1]``.

        Returns:
            ``(output_sample_rate, samples)`` for the next queued chunk.
        """
        if not self.args_set.is_set():
            await self.wait_for_args()
        attempt_running = (
            self._connect_task is not None and not self._connect_task.done()
        )
        # Start at most one connection attempt at a time, and never reconnect
        # once shutdown has been requested.
        if self.session is None and not attempt_running and not self.quit.is_set():
            self._connect_task = asyncio.create_task(
                self.connect(self.latest_args[1])
            )
        array = await self.audio_queue.get()
        return (self.output_sample_rate, array)

    def shutdown(self) -> None:
        """Signal the connection and receive loops to stop."""
        self.quit.set()
111
+
112
+
113
# Size cap for the WebRTC component (elem_id="video-source").
# CSS lengths other than 0 need a unit; a bare `600` makes the max-height
# declaration invalid, so browsers drop it.
css = """
#video-source {max-width: 600px !important; max-height: 600px !important;}
"""
116
+
117
# UI: a header, an API-key prompt, and (once the key is submitted) the
# audio/video chat next to an optional still-image input.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
    <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
        <div style="background-color: var(--block-background-fill); border-radius: 8px">
            <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
        </div>
        <div>
            <h1>Gen AI SDK Voice Chat</h1>
            <p>Speak with Gemini using real-time audio + video streaming</p>
            <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
            <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
        </div>
    </div>
    """
    )
    with gr.Row() as api_key_row:
        # NOTE(review): the textbox is pre-filled from the server environment.
        # On a shared deployment that value is presumably delivered to every
        # visitor's browser - confirm before hosting with a real key set.
        api_key = gr.Textbox(
            label="API Key",
            type="password",
            placeholder="Enter your API Key",
            value=os.getenv("GOOGLE_API_KEY"),
        )
    # Hidden until an API key has been submitted.
    with gr.Row(visible=False) as row:
        with gr.Column():
            webrtc = WebRTC(
                label="Video Chat",
                modality="audio-video",
                mode="send-receive",
                elem_id="video-source",
                # See for changes needed to deploy behind a firewall
                # https://freddyaboulton.github.io/gradio-webrtc/deployment/
                rtc_configuration=None,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(35, 157, 225)",
                icon_button_color="rgb(35, 157, 225)",
            )
        with gr.Column():
            image_input = gr.Image(
                label="Image", type="numpy", sources=["upload", "clipboard"]
            )

    # Input order matters: the handler reads the API key from latest_args[1]
    # and the optional image from latest_args[2].
    webrtc.stream(
        GeminiHandler(),
        inputs=[webrtc, api_key, image_input],
        outputs=[webrtc],
        time_limit=90,
        concurrency_limit=2,
    )
    # Submitting the key swaps the key prompt for the chat row.
    api_key.submit(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        None,
        [api_key_row, row],
    )


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio_webrtc==0.0.28
2
+ google-genai==0.3.0