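# Gradio + fastrtc front-end for the Xiaomi MiMo-Audio WebRTC demo: streams
# microphone audio to a chat backend and plays back the model's spoken reply.
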
import os
import queue
import random
import time
from threading import Thread
from typing import Any, Callable, Literal

import fastrtc
import gradio as gr
import httpx
import numpy as np

from api_schema import (
    AbortController,
    AssistantStyle,
    ChatAudioBytes,
    ChatRequestBody,
    ChatResponseItem,
    ModelNameResponse,
    PresetOptions,
    SamplerConfig,
    TokenizedConversation,
    TokenizedMessage,
)

HF_TOKEN = os.getenv("HF_TOKEN")
SERVER_LIST = os.getenv("SERVER_LIST")
TURN_KEY_ID = os.getenv("TURN_KEY_ID")
TURN_KEY_API_TOKEN = os.getenv("TURN_KEY_API_TOKEN")
CONCURRENCY_LIMIT = os.getenv("CONCURRENCY_LIMIT")

assert SERVER_LIST is not None, "SERVER_LIST environment variable is required."
assert TURN_KEY_ID is not None and TURN_KEY_API_TOKEN is not None, (
    "TURN_KEY_ID and TURN_KEY_API_TOKEN environment variables are required."
)

deployment_server = [
    server_url.strip() for server_url in SERVER_LIST.split(",") if server_url.strip()
]
assert len(deployment_server) > 0, "SERVER_LIST must contain at least one server URL."

default_concurrency_limit = 32
try:
    concurrency_limit = (
        int(CONCURRENCY_LIMIT)
        if CONCURRENCY_LIMIT is not None
        else default_concurrency_limit
    )
except ValueError:
    concurrency_limit = default_concurrency_limit

def chat_server_url(pathname: str = "/") -> httpx.URL:
    # Pick a backend server at random for simple load balancing.
    host = random.choice(deployment_server)
    return httpx.URL(host).join(pathname)


def auth_headers() -> dict[str, str]:
    if HF_TOKEN is None:
        return {}
    return {"Authorization": f"Bearer {HF_TOKEN}"}

def get_cloudflare_turn_credentials(
    ttl: int = 1200,  # 20 minutes
) -> dict[str, Any]:
    with httpx.Client() as client:
        response = client.post(
            f"https://rtc.live.cloudflare.com/v1/turn/keys/{TURN_KEY_ID}/credentials/generate-ice-servers",
            headers={
                "Authorization": f"Bearer {TURN_KEY_API_TOKEN}",
                "Content-Type": "application/json",
            },
            json={"ttl": ttl},
        )
        if response.is_success:
            return response.json()
        else:
            raise Exception(
                f"Failed to get TURN credentials: {response.status_code} {response.text}"
            )

class NeverVAD(fastrtc.PauseDetectionModel):
    """Stub pause-detection model: end-of-turn is decided by ReplyOnMuted
    below, so the VAD itself must never run."""

    def vad(self, *_args, **_kwargs):
        raise RuntimeError("NeverVAD should not be called.")

    def warmup(self):
        pass

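# ReplyOnPause normally ends a turn when its VAD model detects a pause. This
# subclass replaces that flow: the turn ends when the browser microphone is
# muted, which arrives here as a run of (near-)zero samples.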
class ReplyOnMuted(fastrtc.ReplyOnPause):
    def __init__(
        self,
        fn: fastrtc.reply_on_pause.ReplyFnGenerator,
        startup_fn: Callable | None = None,
        can_interrupt: bool = True,
        needs_args: bool = False,
    ):
        super().__init__(
            fn,
            startup_fn,
            None,  # algo_options
            None,  # model_options
            can_interrupt,
            "mono",  # expected_layout
            24000,  # output_sample_rate
            None,  # output_frame_size
            24000,  # input_sample_rate
            NeverVAD(),  # model
            needs_args,
        )

    def copy(self):
        return ReplyOnMuted(
            self.fn,
            self.startup_fn,
            self.can_interrupt,
            self.needs_args,
        )

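    # End-of-turn heuristic on raw int16 samples: |amplitude| < 5 counts as
    # silence (a muted browser mic sends zeros). A turn ends once at least
    # one second of speech has accumulated and the trailing 100 ms is silent.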
    def determine_pause(
        self,
        audio: np.ndarray,  # shape [samples,]
        sampling_rate: int,
        state: fastrtc.reply_on_pause.AppState,
    ):
        chunk_length = len(audio) / sampling_rate
        if chunk_length > 0.1:
            state.buffer = None
        if not state.started_talking:
            if not np.all(np.abs(audio) < 5):
                # First non-silent chunk: the user has started speaking.
                state.started_talking = True
                self.send_message_sync(
                    fastrtc.utils.create_message("log", "started_talking")
                )
        if state.started_talking:
            if state.stream is None:
                state.stream = audio
            else:
                state.stream = np.concatenate((state.stream, audio))
            current_duration = len(state.stream) / sampling_rate
            if current_duration > 1.0:
                last_segment = state.stream[-int(sampling_rate * 0.1) :]
                if np.all(np.abs(last_segment) < 5):
                    return True
        return False

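# Per-connection state: the tokenized conversation history, the current turn
# counter (used for aborts), and all audio recorded so far.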
class ConversationManager:
    def __init__(self, assistant_style: AssistantStyle | None = None):
        self.conversation = TokenizedConversation(messages=[])
        self.turn = 0
        self.assistant_style = assistant_style
        self.last_access_time = time.monotonic()
        self.collected_audio_chunks: list[np.ndarray] = []

    def new_turn(self):
        self.turn += 1
        self.last_access_time = time.monotonic()
        return ConversationAbortController(self)

    def is_idle(self, idle_timeout: float) -> bool:
        return time.monotonic() - self.last_access_time > idle_timeout

    def append_audio_chunk(self, audio_chunk: tuple[int, np.ndarray]):
        sr, audio_data = audio_chunk
        assert sr == 24000, "Only 24kHz audio is supported"
        if audio_data.ndim > 1:
            # Audio arrives as [channels, samples] (not Gradio's
            # [samples, channels]); downmix to mono [samples,].
            audio_data = audio_data.mean(axis=0).astype(np.int16)
        self.collected_audio_chunks.append(audio_data)

    def all_collected_audio(self) -> tuple[int, np.ndarray]:
        sr = 24000
        audio_data = np.concatenate(self.collected_audio_chunks)
        return sr, audio_data

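    # chat() streams one turn: a producer thread POSTs to the backend and
    # parses its SSE response ("data: {json}" lines, ended by "data: [DONE]")
    # into a queue; the generator body consumes the queue, yielding items to
    # the caller and None keep-alives while the queue is empty.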
    def chat(
        self,
        url: httpx.URL,
        chat_id: int,
        input_audio: tuple[int, np.ndarray],
        global_sampler_config: SamplerConfig | None = None,
        local_sampler_config: SamplerConfig | None = None,
    ):
        controller = self.new_turn()
        chat_queue = queue.Queue[ChatResponseItem | None]()
        # Updated by the consumer loop below and read by the producer thread;
        # initialized here so the thread's first read cannot raise NameError.
        consumer_alive_time = time.monotonic()
        producer_error: list[Exception] = []

        def chat_task():
            req = ChatRequestBody(
                conversation=self.conversation,
                input_audio=ChatAudioBytes.from_audio(input_audio),
                assistant_style=self.assistant_style,
                global_sampler_config=global_sampler_config,
                local_sampler_config=local_sampler_config,
            )
            first_output = True
            try:
                with httpx.Client() as client:
                    headers = {
                        "Content-Type": "application/json",
                        **auth_headers(),  # only present when HF_TOKEN is set
                    }
                    with client.stream(
                        method="POST",
                        url=url,
                        content=req.model_dump_json(),
                        headers=headers,
                    ) as response:
                        if response.status_code != 200:
                            raise RuntimeError(f"Error {response.status_code}")
                        for line in response.iter_lines():
                            if not controller.is_alive():
                                print(f"[{chat_id=}] Streaming aborted by user")
                                break
                            if time.monotonic() - consumer_alive_time > 1.0:
                                print(
                                    f"[{chat_id=}] Streaming aborted due to inactivity"
                                )
                                break
                            if not line.startswith("data: "):
                                continue
                            line = line.removeprefix("data: ")
                            if line.strip() == "[DONE]":
                                print(f"[{chat_id=}] Streaming finished by server")
                                break
                            chunk = ChatResponseItem.model_validate_json(line)
                            if chunk.tokenized_input is not None:
                                self.conversation.messages.append(
                                    chunk.tokenized_input,
                                )
                            if chunk.token_chunk is not None:
                                if first_output:
                                    # The first chunk starts a new assistant
                                    # message; later chunks extend it.
                                    self.conversation.messages.append(
                                        TokenizedMessage(
                                            role="assistant",
                                            content=chunk.token_chunk,
                                        )
                                    )
                                    first_output = False
                                else:
                                    self.conversation.messages[-1].append(
                                        chunk.token_chunk,
                                    )
                            chat_queue.put(chunk)
            except Exception as e:
                producer_error.append(e)
            finally:
                # Always unblock the consumer, even if the request failed.
                chat_queue.put(None)

        Thread(target=chat_task, daemon=True).start()
        while True:
            consumer_alive_time = time.monotonic()
            try:
                item = chat_queue.get(timeout=0.1)
                if item is None:
                    break
                yield item
                self.last_access_time = time.monotonic()
            except queue.Empty:
                yield None  # keep-alive
        if producer_error:
            raise RuntimeError(str(producer_error[0]))

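# Inline Feather "mic" icon for the usage notes; muted=True adds the diagonal
# strike-through line.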
def get_microphone_svg(muted: bool | None = None):
    muted_svg = '<line x1="1" y1="1" x2="23" y2="23"></line>' if muted else ""
    return f"""
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="feather feather-mic" style="display: inline; vertical-align: middle;">
    <path d="M12 1a3 3 0 0 0-3 3v8a3 3 0 0 0 6 0V4a3 3 0 0 0-3-3z"></path>
    <path d="M19 10v2a7 7 0 0 1-14 0v-2"></path>
    <line x1="12" y1="19" x2="12" y2="23"></line>
    <line x1="8" y1="23" x2="16" y2="23"></line>
    {muted_svg}
</svg>
"""

class ConversationAbortController(AbortController):
    manager: ConversationManager
    cur_turn: int | None

    def __init__(self, manager: ConversationManager):
        self.manager = manager
        self.cur_turn = manager.turn

    def is_alive(self) -> bool:
        # Starting a new turn (or calling abort()) invalidates this controller.
        return self.manager.turn == self.cur_turn

    def abort(self) -> None:
        self.cur_turn = None

chat_id_counter = 0


def new_chat_id():
    global chat_id_counter
    chat_id = chat_id_counter
    chat_id_counter += 1
    return chat_id

def main():
    print("Starting WebRTC server")
    conversations: dict[str, ConversationManager] = {}

    def cleanup_idle_conversations():
        idle_timeout = 30 * 60.0  # 30 minutes
        while True:
            time.sleep(60)
            to_delete = []
            for webrtc_id, manager in conversations.items():
                if manager.is_idle(idle_timeout):
                    to_delete.append(webrtc_id)
            for webrtc_id in to_delete:
                print(f"Cleaning up idle conversation {webrtc_id}")
                del conversations[webrtc_id]

    Thread(target=cleanup_idle_conversations, daemon=True).start()

    def get_preset_list(category: Literal["character", "voice"]) -> list[str]:
        url = chat_server_url(f"/preset/{category}")
        with httpx.Client() as client:
            response = client.get(url, headers=auth_headers())
            if response.status_code == 200:
                return PresetOptions.model_validate_json(response.text).options
        return ["[default]"]

    def get_model_name() -> str:
        url = chat_server_url("/model-name")
        with httpx.Client() as client:
            response = client.get(url, headers=auth_headers())
            if response.status_code == 200:
                return ModelNameResponse.model_validate_json(response.text).model_name
        return "unknown"

    def load_initial_data():
        model_name = get_model_name()
        title = f"Xiaomi MiMo-Audio WebRTC (model: {model_name})"
        character_choices = get_preset_list("character")
        voice_choices = get_preset_list("voice")
        return (
            gr.update(value=f"# {title}"),
            gr.update(choices=character_choices),
            gr.update(choices=voice_choices),
        )

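    # Streaming turn handler invoked by ReplyOnMuted: yields (sample_rate,
    # samples) audio tuples into the WebRTC stream, AdditionalOutputs for the
    # text/status/full-audio widgets, and None as a keep-alive while waiting.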
    def response(
        input_audio: tuple[int, np.ndarray],
        webrtc_id: str,
        preset_character: str | None,
        preset_voice: str | None,
        custom_character_prompt: str | None,
    ):
        nonlocal conversations
        if webrtc_id not in conversations:
            # Treat a missing or blank custom prompt as "no custom prompt".
            custom_character_prompt = (custom_character_prompt or "").strip() or None
            conversations[webrtc_id] = ConversationManager(
                assistant_style=AssistantStyle(
                    preset_character=preset_character,
                    custom_character_prompt=custom_character_prompt,
                    preset_voice=preset_voice,
                )
            )
        manager = conversations[webrtc_id]
        sr, audio_data = input_audio
        chat_id = new_chat_id()
        print(
            f"WebRTC {webrtc_id} [{chat_id=}]: Input {audio_data.shape[-1] / sr:.2f}s"
        )
        # Record input audio
        manager.append_audio_chunk(input_audio)
        output_text = ""
        status_text = "⚙️ Preparing..."
        text_active = False
        audio_active = False
        collected_audio: tuple[int, np.ndarray] | None = None

        def additional_outputs():
            return fastrtc.AdditionalOutputs(
                output_text,
                status_text,
                collected_audio,
            )

        yield additional_outputs()
        try:
            url = chat_server_url("/audio-chat")
            for chunk in manager.chat(
                url,
                chat_id,
                input_audio,
            ):
                if chunk is None:
                    # Keep-alive: shows manager.chat() this consumer is alive.
                    yield None
                    continue
                if chunk.text_chunk is not None:
                    text_active = True
                    output_text += chunk.text_chunk
                if chunk.end_of_transcription:
                    text_active = False
                if chunk.audio_chunk is not None:
                    audio_active = True
                    audio = chunk.audio_chunk.to_audio()
                    manager.append_audio_chunk(audio)
                    yield audio
                if chunk.end_of_stream:
                    audio_active = False
                if text_active and audio_active:
                    status_text = "💬+🔊 Mixed"
                elif text_active:
                    status_text = "💬 Text"
                elif audio_active:
                    status_text = "🔊 Audio"
                if chunk.stop_reason is not None:
                    status_text = f"✅ Finished: {chunk.stop_reason}"
                yield additional_outputs()
        except RuntimeError as e:
            status_text = f"❌ Error: {e}"
            yield additional_outputs()
        collected_audio = manager.all_collected_audio()
        yield additional_outputs()

    title = "Xiaomi MiMo-Audio WebRTC"
    with gr.Blocks(title=title) as demo:
        title_markdown = gr.Markdown(f"# {title}")
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Usage"):
                    gr.HTML(
                        "<li>Note: FastRTC's built-in VAD is quite sensitive, so for better stability across environments this demo uses a manual end-of-speech flow: it simply detects whether the microphone is muted. This can degrade the experience with microphones that apply automatic noise suppression. We are looking for a stable VAD model that works well with FastRTC.</li>"
                        "<li>Click Request Microphone to grant permission, click Record to start a turn, and click Stop to end the turn and clear the conversation history.</li>"
                        f"<li>After you finish speaking, click the microphone icon {get_microphone_svg()} to end your input and wait for MiMo's reply.</li>"
                        f"<li>While MiMo is speaking, you can interrupt by clicking the muted microphone icon {get_microphone_svg(muted=True)} and then speaking a new instruction.</li>"
                    )
                chat = fastrtc.WebRTC(
                    label="WebRTC Chat",
                    modality="audio",
                    mode="send-receive",
                    full_screen=False,
                    rtc_configuration=get_cloudflare_turn_credentials,
                )
                output_text = gr.Textbox(label="Output", lines=3, interactive=False)
                status_text = gr.Textbox(label="Status", lines=1, interactive=False)
                with gr.Accordion("Advanced", open=True):
                    collected_audio = gr.Audio(
                        label="Full Audio",
                        type="numpy",
                        format="wav",
                        interactive=False,
                    )
            with gr.Column():
                with gr.Accordion("Settings Help"):
                    gr.Markdown(
                        "- `Preset Prompt` controls the response style.\n"
                        "- `Preset Voice` controls the speaking tone.\n"
                        "- `Custom Prompt` lets you define the response style in natural language (overrides `Preset Prompt`).\n"
                        "- For best results, choose prompts and voices that match your language.\n"
                        "- To apply new settings, end the current conversation and start a new one."
                    )
                preset_character_dropdown = gr.Dropdown(
                    label="📝 Preset Prompt",
                    choices=["[default]"],
                )
                preset_voice_dropdown = gr.Dropdown(
                    label="🎤 Preset Voice",
                    choices=["[default]"],
                )
                custom_character_prompt = gr.Textbox(
                    label="🛠️ Custom Prompt",
                    placeholder="For example: You are Xiaomi MiMo-Audio, a large language model trained by Xiaomi. You are chatting with a user over voice.",
                    lines=2,
                    interactive=True,
                )
        chat.stream(
            ReplyOnMuted(response),
            inputs=[
                chat,
                preset_character_dropdown,
                preset_voice_dropdown,
                custom_character_prompt,
            ],
            concurrency_limit=concurrency_limit,
            outputs=[chat],
        )
        chat.on_additional_outputs(
            lambda *args: args,
            outputs=[output_text, status_text, collected_audio],
            concurrency_limit=concurrency_limit,
            show_progress="hidden",
        )
        demo.load(
            load_initial_data,
            inputs=[],
            outputs=[title_markdown, preset_character_dropdown, preset_voice_dropdown],
        )
    demo.launch()


if __name__ == "__main__":
    main()