freddyaboulton HF staff commited on
Commit
a461509
1 Parent(s): 04b5051

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+ import gradio as gr
3
+ from gradio_webrtc import WebRTC, StreamHandler, AdditionalOutputs
4
+ from numpy import ndarray
5
+ import sphn
6
+ import websockets.sync.client
7
+ import numpy as np
8
+
9
+
10
+ class MoshiHandler(StreamHandler):
11
+
12
+ def __init__(self,
13
+ url: str,
14
+ expected_layout: Literal['mono', 'stereo'] = "mono",
15
+ output_sample_rate: int = 24000,
16
+ output_frame_size: int = 480) -> None:
17
+ self.url = url
18
+ proto, without_proto = self.url.split('://', 1)
19
+ if proto in ['ws', 'http']:
20
+ proto = "ws"
21
+ elif proto in ['wss', 'https']:
22
+ proto = "wss"
23
+
24
+ self._generator = None
25
+ self.output_chunk_size = 1920
26
+ self.ws = None
27
+ self.ws_url = f"{proto}://{without_proto}/api/chat"
28
+ self.stream_reader = sphn.OpusStreamReader(output_sample_rate)
29
+ self.stream_writer = sphn.OpusStreamWriter(output_sample_rate)
30
+ self.all_output_data = None
31
+ super().__init__(expected_layout, output_sample_rate, output_frame_size,
32
+ input_sample_rate=24000)
33
+
34
+ def receive(self, frame: tuple[int, ndarray]) -> None:
35
+ if not self.ws:
36
+ self.ws = websockets.sync.client.connect(self.ws_url)
37
+ _, array = frame
38
+ array = array.squeeze().astype(np.float32) / 32768.0
39
+ self.stream_writer.append_pcm(array)
40
+ bytes = b"\x01" + self.stream_writer.read_bytes()
41
+ self.ws.send(bytes)
42
+
43
+ def generator(self):
44
+ for message in self.ws:
45
+ if len(message) == 0:
46
+ yield None
47
+ kind = message[0]
48
+ if kind == 1:
49
+ payload = message[1:]
50
+ self.stream_reader.append_bytes(payload)
51
+ pcm = self.stream_reader.read_pcm()
52
+ if self.all_output_data is None:
53
+ self.all_output_data = pcm
54
+ else:
55
+ self.all_output_data = np.concatenate((self.all_output_data, pcm))
56
+ while self.all_output_data.shape[-1] >= self.output_chunk_size:
57
+ yield (self.output_sample_rate, self.all_output_data[: self.output_chunk_size].reshape(1, -1))
58
+ self.all_output_data = np.array(self.all_output_data[self.output_chunk_size :])
59
+ elif kind == 2:
60
+ payload = message[1:]
61
+ yield AdditionalOutputs(payload.decode())
62
+
63
+
64
+ def emit(self) -> tuple[int, ndarray] | None:
65
+ if not self.ws:
66
+ return
67
+ if not self._generator:
68
+ self._generator = self.generator()
69
+ try:
70
+ return next(self._generator)
71
+ except StopIteration:
72
+ self.reset()
73
+
74
+ def reset(self) -> None:
75
+ self._generator = None
76
+ self.all_output_data = None
77
+
78
+ def copy(self) -> StreamHandler:
79
+ return MoshiHandler(self.url,
80
+ self.expected_layout,
81
+ self.output_sample_rate, self.output_frame_size)
82
+
83
+ def shutdown(self) -> None:
84
+ if self.ws:
85
+ self.ws.close()
86
+
87
+
88
+
89
+ with gr.Blocks() as demo:
90
+ gr.HTML(
91
+ """
92
+ <div style='text-align: center'>
93
+ <h1>
94
+ Talk To Moshi (Powered by WebRTC ⚡️)
95
+ </h1>
96
+ <p>
97
+ Each conversation is limited to 90 seconds. Once the time limit is up you can rejoin the conversation.
98
+ </p>
99
+ </div>
100
+ """
101
+ )
102
+ response = gr.State(value="")
103
+ chatbot = gr.Chatbot(type="messages", value=[])
104
+ webrtc = WebRTC(label="Conversation", modality="audio", mode="send-receive", rtc_configuration=None)
105
+ webrtc.stream(MoshiHandler("https://freddyaboulton-moshi-server.hf.space"),
106
+ inputs=[webrtc, chatbot], outputs=[webrtc], time_limit=90)
107
+
108
+ def on_text(state, response):
109
+ print("response", response)
110
+ return state + response
111
+
112
+ def add_text(chat_history, response):
113
+ if len(chat_history) == 0:
114
+ chat_history.append({"role": "assistant", "content": response})
115
+ else:
116
+ chat_history[-1]["content"] = response
117
+ return chat_history
118
+
119
+ webrtc.on_additional_outputs(on_text,
120
+ inputs=[response], outputs=response,
121
+ queue=True,
122
+ show_progress="hidden")
123
+ response.change(add_text, inputs=[chatbot, response], outputs=chatbot,
124
+ queue=True, show_progress="hidden")
125
+
126
+ demo.launch()