from typing import Literal

import gradio as gr
import numpy as np
import sphn
import websockets.sync.client
from gradio_webrtc import AdditionalOutputs, StreamHandler, WebRTC
from numpy import ndarray
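
# Wire protocol spoken over the Moshi server's /api/chat websocket, as used by
# MoshiHandler below: each websocket message is one tag byte followed by a
# payload; tag 0x01 carries Opus-encoded audio and tag 0x02 carries UTF-8 text.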


class MoshiHandler(StreamHandler):
    """Relay microphone audio to a Moshi server over a websocket and stream the
    audio (and text) the server sends back."""

    def __init__(
        self,
        url: str,
        expected_layout: Literal["mono", "stereo"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        self.url = url
        proto, without_proto = self.url.split("://", 1)
        if proto in ("ws", "http"):
            proto = "ws"
        elif proto in ("wss", "https"):
            proto = "wss"
        else:
            raise ValueError(f"Unsupported URL scheme: {proto}")
        self._generator = None
        self.output_chunk_size = 1920  # samples per emitted frame (80 ms at 24 kHz)
        self.ws = None
        self.ws_url = f"{proto}://{without_proto}/api/chat"
        self.stream_reader = sphn.OpusStreamReader(output_sample_rate)
        self.stream_writer = sphn.OpusStreamWriter(output_sample_rate)
        self.all_output_data = None
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=24000,
        )
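
    # Send path: each incoming microphone frame arrives as int16 PCM, is scaled
    # to float32, Opus-encoded with sphn, and sent to the server behind a 0x01 tag.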
    def receive(self, frame: tuple[int, ndarray]) -> None:
        if not self.ws:
            self.ws = websockets.sync.client.connect(self.ws_url)
        _, array = frame
        array = array.squeeze().astype(np.float32) / 32768.0
        self.stream_writer.append_pcm(array)
        payload = b"\x01" + self.stream_writer.read_bytes()
        self.ws.send(payload)
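
    # Receive path: server messages are demultiplexed by their tag byte. Audio
    # (0x01) is Opus-decoded and buffered until a full output chunk is available;
    # text (0x02) is surfaced to the UI through AdditionalOutputs.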
    def generator(self):
        for message in self.ws:
            if len(message) == 0:
                yield None
                continue
            kind = message[0]
            if kind == 1:  # audio
                payload = message[1:]
                self.stream_reader.append_bytes(payload)
                pcm = self.stream_reader.read_pcm()
                if self.all_output_data is None:
                    self.all_output_data = pcm
                else:
                    self.all_output_data = np.concatenate((self.all_output_data, pcm))
                while self.all_output_data.shape[-1] >= self.output_chunk_size:
                    yield (
                        self.output_sample_rate,
                        self.all_output_data[: self.output_chunk_size].reshape(1, -1),
                    )
                    self.all_output_data = np.array(
                        self.all_output_data[self.output_chunk_size :]
                    )
            elif kind == 2:  # text
                payload = message[1:]
                yield AdditionalOutputs(payload.decode())

    def emit(self) -> tuple[int, ndarray] | None:
        if not self.ws:
            return None
        if not self._generator:
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            self.reset()
            return None

    def reset(self) -> None:
        self._generator = None
        self.all_output_data = None

    def copy(self) -> StreamHandler:
        return MoshiHandler(
            self.url,
            self.expected_layout,
            self.output_sample_rate,
            self.output_frame_size,
        )

    def shutdown(self) -> None:
        if self.ws:
            self.ws.close()
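

# Gradio UI: the WebRTC component streams microphone audio through MoshiHandler;
# the handler's AdditionalOutputs (text tokens) are accumulated in the `response`
# State, and every change of that State rewrites the last chatbot message.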
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style='text-align: center'>
            <h1>
                Talk To Moshi (Powered by WebRTC ⚡️)
            </h1>
            <p>
                Each conversation is limited to 90 seconds. Once the time limit is up, you can rejoin the conversation.
            </p>
        </div>
        """
    )
    response = gr.State(value="")
    chatbot = gr.Chatbot(type="messages", value=[])
    webrtc = WebRTC(label="Conversation", modality="audio", mode="send-receive", rtc_configuration=None)
    webrtc.stream(
        MoshiHandler("https://freddyaboulton-moshi-server.hf.space"),
        inputs=[webrtc, chatbot],
        outputs=[webrtc],
        time_limit=90,
    )

    def on_text(state, response):
        # Accumulate the text tokens streamed back by the server.
        print("response", response)
        return state + response

    def add_text(chat_history, response):
        # Keep a single assistant message whose content is the full transcript.
        if len(chat_history) == 0:
            chat_history.append({"role": "assistant", "content": response})
        else:
            chat_history[-1]["content"] = response
        return chat_history

    webrtc.on_additional_outputs(
        on_text,
        inputs=[response],
        outputs=response,
        queue=True,
        show_progress="hidden",
    )
    response.change(
        add_text,
        inputs=[chatbot, response],
        outputs=chatbot,
        queue=True,
        show_progress="hidden",
    )

demo.launch()