Commit · c760a78
1 Parent(s): a746ceb
init commit

Files changed:
- .gitignore        +2   -0
- README.md         +6   -5
- api_schema.py     +154 -0
- app.py            +476 -4
- requirements.txt  +5   -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+**/__pycache__/**
+**/tmp/**
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Gradio Front Interface
+emoji: π
+colorFrom: yellow
+colorTo: indigo
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.44.1
 app_file: app.py
 pinned: false
+python_version: 3.12.7
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
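(Note: the new `python_version: 3.12.7` pin is load-bearing: `api_schema.py` below uses PEP 695 type-parameter syntax (`def default_clip[T: int | float](...)`) and `app.py` imports `typing.override`, both of which require Python ≥ 3.12. Note also that `sdk_version: 5.44.1` declared here differs from the `gradio==5.35.0` pinned in `requirements.txt` below.)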
api_schema.py ADDED
@@ -0,0 +1,154 @@
+from abc import ABC
+from io import BytesIO
+from typing import Literal
+
+import numpy as np
+from pydantic import BaseModel, ConfigDict
+
+
+class AbortController(ABC):
+    def is_alive(self) -> bool:
+        raise NotImplementedError
+
+
+class NeverAbortedController(AbortController):
+    def is_alive(self) -> bool:
+        return True
+
+
+def is_none_or_alive(abort_controller: AbortController | None) -> bool:
+    return abort_controller is None or abort_controller.is_alive()
+
+
+class ModelNameResponse(BaseModel):
+    model_name: str
+
+
+class TokenizedMessage(BaseModel):
+    role: Literal["user", "assistant"]
+    content: list[list[int]]
+    """[audio_channels+1, time_steps]"""
+
+    def time_steps(self) -> int:
+        return len(self.content[0])
+
+    def append(self, chunk: list[list[int]]):
+        assert len(chunk) == len(self.content), "Incompatible chunk length"
+        assert all(len(c) == len(chunk[0]) for c in chunk), "Incompatible chunk shape"
+        for content_channel, chunk_channel in zip(self.content, chunk):
+            content_channel.extend(chunk_channel)
+
+
+class TokenizedConversation(BaseModel):
+    messages: list[TokenizedMessage]
+
+    def time_steps(self) -> int:
+        return sum(msg.time_steps() for msg in self.messages)
+
+    def latest_messages(self, max_time_steps: int) -> "list[TokenizedMessage]":
+        sum_time_steps = 0
+        selected_messages: list[TokenizedMessage] = []
+
+        for msg in reversed(self.messages):
+            cur_time_steps = msg.time_steps()
+            if sum_time_steps + cur_time_steps > max_time_steps:
+                break
+            sum_time_steps += cur_time_steps
+            selected_messages.append(msg)
+
+        return list(reversed(selected_messages))
+
+
+class ChatAudioBytes(BaseModel):
+    model_config = ConfigDict(ser_json_bytes="base64", val_json_bytes="base64")
+    sample_rate: int
+    audio_data: bytes
+    """
+    shape = (channels, samples) or (samples,);
+    dtype = int16 or float32
+    """
+
+    @classmethod
+    def from_audio(cls, audio: tuple[int, np.ndarray]) -> "ChatAudioBytes":
+        buf = BytesIO()
+        np.save(buf, audio[1])
+        return ChatAudioBytes(sample_rate=audio[0], audio_data=buf.getvalue())
+
+    def to_audio(self) -> tuple[int, np.ndarray]:
+        buf = BytesIO(self.audio_data)
+        audio_np = np.load(buf)
+        return self.sample_rate, audio_np
+
+
+class ChatResponseItem(BaseModel):
+    tokenized_input: TokenizedMessage | None = None
+    token_chunk: list[list[int]] | None = None
+    """[audio_channels+1, time_steps]"""
+    text_chunk: str | None = None
+    audio_chunk: ChatAudioBytes | None = None
+    end_of_stream: bool | None = None
+    """Represents the special token <|eostm|>"""
+    end_of_transcription: bool | None = None
+    """Represents the special token <|eot|> (not <|endoftext|>)"""
+    stop_reason: str | None = None
+    """The reason why generation stopped, e.g., max_new_tokens, max_length, stop_token, aborted"""
+
+
+class AssistantStyle(BaseModel):
+    preset_character: str | None = None
+    custom_character_prompt: str | None = None
+
+    preset_voice: str | None = None
+    custom_voice: ChatAudioBytes | None = None
+
+
+class SamplerConfig(BaseModel):
+    """
+    Sampling configuration for text/audio generation.
+
+    - If some fields are not set, their effects are disabled.
+    - If the entire config is not set (e.g., `global_sampler_config=None`), all fields are automatically determined.
+    - Use `temperature=0.0`/`top_k=1`/`top_p=0.0` instead of `do_sample=False` to disable sampling.
+    """
+
+    temperature: float | None = None
+    top_k: int | None = None
+    top_p: float | None = None
+
+    def normalized(self) -> tuple[float, int, float]:
+        """
+        Returns:
+            A tuple (temperature, top_k, top_p) with normalized values.
+        """
+        if (
+            (self.temperature is not None and self.temperature <= 0.0)
+            or (self.top_k is not None and self.top_k <= 1)
+            or (self.top_p is not None and self.top_p <= 0.0)
+        ):
+            return (1.0, 1, 1.0)
+
+        def default_clip[T: int | float](
+            value: T | None, default_value: T, min_value: T, max_value: T
+        ) -> T:
+            if value is None:
+                return default_value
+            return max(min(value, max_value), min_value)
+
+        temperature = default_clip(self.temperature, 1.0, 0.01, 2.0)
+        top_k = default_clip(self.top_k, 1_000_000, 1, 1_000_000)
+        top_p = default_clip(self.top_p, 1.0, 0.01, 1.0)
+
+        return (temperature, top_k, top_p)
+
+
+class ChatRequestBody(BaseModel):
+    conversation: TokenizedConversation | None = None
+    input_text: str | None = None
+    input_audio: ChatAudioBytes | None = None
+    assistant_style: AssistantStyle | None = None
+    global_sampler_config: SamplerConfig | None = None
+    local_sampler_config: SamplerConfig | None = None
+
+
+class PresetOptions(BaseModel):
+    options: list[str]
app.py CHANGED
@@ -1,7 +1,479 @@
+import argparse
+import queue
+import time
+from threading import Thread
+from typing import Literal, override
+import os
+
+import fastrtc
+from fastrtc import get_cloudflare_turn_credentials_async
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-
+import httpx
+import numpy as np
+from pydantic import BaseModel
+import random
+
+
+from api_schema import (
+    AbortController,
+    AssistantStyle,
+    ChatAudioBytes,
+    ChatRequestBody,
+    ChatResponseItem,
+    ModelNameResponse,
+    PresetOptions,
+    SamplerConfig,
+    TokenizedConversation,
+    TokenizedMessage,
+)
+
+HF_TOKEN = os.getenv("HF_TOKEN")
+if HF_TOKEN is None:
+    print(
+        "⚠️ [WARNING] HF_TOKEN environment variable not found.\n"
+        "WebRTC connections may fail on Hugging Face Spaces because TURN service cannot be used.\n"
+        "💡 Solution: Go to your Hugging Face Space → Settings → Secrets, "
+        "add a variable named HF_TOKEN or HF_ACCESS_TOKEN with your personal access token (with at least 'read' permission)."
+    )
+else:
+    print("✅ [INFO] HF_TOKEN detected. WebRTC will use Hugging Face TURN service for connectivity.")
+
+
+url_prefix = os.getenv("URL_PREFIX")
+server_number = int(os.getenv("NUM_SERVER"))
+print(url_prefix)
+print(server_number)
+
+deployment_server = []
+for i in range(1, server_number + 1):
+    url = url_prefix + str(i) + ".hf.space"
+    deployment_server.append(url)
+print(deployment_server)
+
+
+class Args(BaseModel):
+    host: str
+    port: int
+    concurrency_limit: int
+    share: bool
+    debug: bool
+    chat_server: str
+    tag: str | None = None
+
+    @classmethod
+    def parse_args(cls):
+        parser = argparse.ArgumentParser(description="Xiaomi MiMo-Audio Chat")
+        parser.add_argument("--host", default="0.0.0.0")
+        parser.add_argument("--port", type=int, default=8087)
+        parser.add_argument("--concurrency-limit", type=int, default=40)
+        parser.add_argument("--share", action="store_true")
+        parser.add_argument("--debug", action="store_true")
+        parser.add_argument(
+            "-S",
+            "--chat-server",
+            dest="chat_server",
+            type=str,
+            default="deployment_docker_1",
+        )
+        parser.add_argument("--tag", type=str)
+
+        args = parser.parse_args()
+        return cls.model_validate(vars(args))
+
+    def chat_server_url(self):
+        return deployment_server[random.randint(0, server_number - 1)]
+        # if self.chat_server in global_chat_server_map:
+        #     return global_chat_server_map[self.chat_server]
+
+        # return self.chat_server
+
+
+class ConversationManager:
+    def __init__(self, assistant_style: AssistantStyle | None = None):
+        self.conversation = TokenizedConversation(messages=[])
+        self.turn = 0
+        self.assistant_style = assistant_style
+        self.last_access_time = time.monotonic()
+        self.collected_audio_chunks: list[np.ndarray] = []
+
+    def new_turn(self):
+        self.turn += 1
+        self.last_access_time = time.monotonic()
+        return ConversationAbortController(self)
+
+    def is_idle(self, idle_timeout: float) -> bool:
+        return time.monotonic() - self.last_access_time > idle_timeout
+
+    def append_audio_chunk(self, audio_chunk: tuple[int, np.ndarray]):
+        sr, audio_data = audio_chunk
+        assert sr == 24000, "Only 24kHz audio is supported"
+        if audio_data.ndim > 1:
+            # [channels, samples] -> [samples,]
+            # Not Gradio style
+            audio_data = audio_data.mean(axis=0).astype(np.int16)
+        self.collected_audio_chunks.append(audio_data)
+
+    def all_collected_audio(self) -> tuple[int, np.ndarray]:
+        sr = 24000
+        audio_data = np.concatenate(self.collected_audio_chunks)
+        return sr, audio_data
+
+    def chat(
+        self,
+        url: httpx.URL,
+        chat_id: int,
+        input_audio: tuple[int, np.ndarray],
+        global_sampler_config: SamplerConfig | None = None,
+        local_sampler_config: SamplerConfig | None = None,
+    ):
+        controller = self.new_turn()
+        chat_queue = queue.Queue[ChatResponseItem | None]()
+
+        def chat_task():
+            req = ChatRequestBody(
+                conversation=self.conversation,
+                input_audio=ChatAudioBytes.from_audio(input_audio),
+                assistant_style=self.assistant_style,
+                global_sampler_config=global_sampler_config,
+                local_sampler_config=local_sampler_config,
+            )
+            first_output = True
+            with httpx.Client() as client:
+                headers = {
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {HF_TOKEN}",  # <-- add this line
+                }
+                with client.stream(
+                    method="POST",
+                    url=url,
+                    content=req.model_dump_json(),
+                    headers=headers,
+                ) as response:
+                    if response.status_code != 200:
+                        raise RuntimeError(f"Error {response.status_code}")
+
+                    for line in response.iter_lines():
+                        if not controller.is_alive():
+                            print(f"[{chat_id=}] Streaming aborted by user")
+                            break
+                        if time.monotonic() - consumer_alive_time > 1.0:
+                            print(f"[{chat_id=}] Streaming aborted due to inactivity")
+                            break
+                        if not line.startswith("data: "):
+                            continue
+                        line = line.removeprefix("data: ")
+                        if line.strip() == "[DONE]":
+                            print(f"[{chat_id=}] Streaming finished by server")
+                            break
+
+                        chunk = ChatResponseItem.model_validate_json(line)
+
+                        if chunk.tokenized_input is not None:
+                            self.conversation.messages.append(
+                                chunk.tokenized_input,
+                            )
+
+                        if chunk.token_chunk is not None:
+                            if first_output:
+                                self.conversation.messages.append(
+                                    TokenizedMessage(
+                                        role="assistant",
+                                        content=chunk.token_chunk,
+                                    )
+                                )
+                                first_output = False
+                            else:
+                                self.conversation.messages[-1].append(
+                                    chunk.token_chunk,
+                                )
+
+                        chat_queue.put(chunk)
+
+            chat_queue.put(None)
+
+        consumer_alive_time = time.monotonic()  # set before the worker thread can read it
+        Thread(target=chat_task, daemon=True).start()
+
+        while True:
+            consumer_alive_time = time.monotonic()
+            try:
+                item = chat_queue.get(timeout=0.1)
+                if item is None:
+                    break
+                yield item
+                self.last_access_time = time.monotonic()
+            except queue.Empty:
+                yield None
+
+
+class ConversationAbortController(AbortController):
+    manager: ConversationManager
+    cur_turn: int | None
+
+    def __init__(self, manager: ConversationManager):
+        self.manager = manager
+        self.cur_turn = manager.turn
+
+    @override
+    def is_alive(self) -> bool:
+        return self.manager.turn == self.cur_turn
+
+    def abort(self) -> None:
+        self.cur_turn = None
+
+
+chat_id_counter = 0
+
+
+def new_chat_id():
+    global chat_id_counter
+    chat_id = chat_id_counter
+    chat_id_counter += 1
+    return chat_id
+
+
+def main():
+    args = Args.parse_args()
+
+    print("Starting WebRTC server")
+
+    conversations: dict[str, ConversationManager] = {}
+
+    def cleanup_idle_conversations():
+        idle_timeout = 30 * 60.0  # 30 minutes
+        while True:
+            time.sleep(60)
+            to_delete = []
+            for webrtc_id, manager in conversations.items():
+                if manager.is_idle(idle_timeout):
+                    to_delete.append(webrtc_id)
+            for webrtc_id in to_delete:
+                print(f"Cleaning up idle conversation {webrtc_id}")
+                del conversations[webrtc_id]
+
+    Thread(target=cleanup_idle_conversations, daemon=True).start()
+
+    def get_preset_list(category: Literal["character", "voice"]) -> list[str]:
+        url = httpx.URL(args.chat_server_url()).join(f"/preset/{category}")
+        headers = {
+            "Authorization": f"Bearer {HF_TOKEN}"  # <-- add the token
+        }
+        with httpx.Client() as client:
+            response = client.get(url, headers=headers)
+            if response.status_code == 200:
+                return PresetOptions.model_validate_json(response.text).options
+            return ["[default]"]
+
+    def get_model_name() -> str:
+        url = httpx.URL(args.chat_server_url()).join("/model-name")
+        headers = {
+            "Authorization": f"Bearer {HF_TOKEN}"  # <-- add the token
+        }
+        with httpx.Client() as client:
+            response = client.get(url, headers=headers)
+            if response.status_code == 200:
+                return ModelNameResponse.model_validate_json(response.text).model_name
+            return "unknown"
+
+    def load_initial_data():
+        model_name = get_model_name()
+        title = f"Xiaomi MiMo-Audio WebRTC (model: {model_name})"
+        if args.tag is not None:
+            title = f"{args.tag} - {title}"
+        character_choices = get_preset_list("character")
+        voice_choices = get_preset_list("voice")
+        return (
+            gr.update(value=f"# {title}"),
+            gr.update(choices=character_choices),
+            gr.update(choices=voice_choices),
+        )
+
+    def response(
+        input_audio: tuple[int, np.ndarray],
+        webrtc_id: str,
+        preset_character: str | None,
+        preset_voice: str | None,
+        custom_character_prompt: str | None,
+    ):
+        headers = {
+            "Authorization": f"Bearer {HF_TOKEN}"  # <-- add the token
+        }
+        # deprecate gc
+        # with httpx.Client() as client:
+        #     client.get(httpx.URL(args.chat_server_url()).join("/gc"), headers=headers)
+        nonlocal conversations
+
+        if webrtc_id not in conversations:
+            custom_character_prompt = custom_character_prompt.strip()
+            if custom_character_prompt == "":
+                custom_character_prompt = None
+            conversations[webrtc_id] = ConversationManager(
+                assistant_style=AssistantStyle(
+                    preset_character=preset_character,
+                    custom_character_prompt=custom_character_prompt,
+                    preset_voice=preset_voice,
+                )
+            )
+
+        manager = conversations[webrtc_id]
+
+        sr, audio_data = input_audio
+        chat_id = new_chat_id()
+        print(f"WebRTC {webrtc_id} [{chat_id=}]: Input {audio_data.shape[1] / sr}s")
+
+        # Record input audio
+        manager.append_audio_chunk(input_audio)
+
+        output_text = ""
+        status_text = "⚙️ Preparing..."
+        text_active = False
+        audio_active = False
+        collected_audio: tuple[int, np.ndarray] | None = None
+
+        def additional_outputs():
+            return fastrtc.AdditionalOutputs(
+                output_text,
+                status_text,
+                collected_audio,
+            )
+
+        yield additional_outputs()
+
+        try:
+            url = httpx.URL(args.chat_server_url()).join("/audio-chat")
+            for chunk in manager.chat(
+                url,
+                chat_id,
+                input_audio,
+            ):
+                if chunk is None:
+                    # Test if consumer is still alive
+                    yield None
+                    continue
+
+                if chunk.text_chunk is not None:
+                    text_active = True
+                    output_text += chunk.text_chunk
+
+                if chunk.end_of_transcription:
+                    text_active = False
+
+                if chunk.audio_chunk is not None:
+                    audio_active = True
+                    audio = chunk.audio_chunk.to_audio()
+                    manager.append_audio_chunk(audio)
+                    yield audio
+
+                if chunk.end_of_stream:
+                    audio_active = False
+
+                if text_active and audio_active:
+                    status_text = "💬+🔊 Mixed"
+                elif text_active:
+                    status_text = "💬 Text"
+                elif audio_active:
+                    status_text = "🔊 Audio"
+
+                if chunk.stop_reason is not None:
+                    status_text = f"✅ Finished: {chunk.stop_reason}"
+
+                yield additional_outputs()
+
+        except RuntimeError as e:
+            status_text = f"❌ Error: {e}"
+            yield additional_outputs()
+
+        collected_audio = manager.all_collected_audio()
+        yield additional_outputs()
+
+    title = "Xiaomi MiMo-Audio WebRTC"
+    if args.tag is not None:
+        title = f"{args.tag} - {title}"
+
+    with gr.Blocks(title=title) as demo:
+        title_markdown = gr.Markdown(f"# {title}")
+        with gr.Row():
+            with gr.Column():
+                chat = fastrtc.WebRTC(
+                    label="WebRTC Chat",
+                    modality="audio",
+                    mode="send-receive",
+                    full_screen=False,
+                    rtc_configuration=get_cloudflare_turn_credentials_async,
+                    # server_rtc_configuration=get_hf_turn_credentials(ttl=600 * 1000),
+                    # rtc_configuration=get_hf_turn_credentials,
+                )
+                output_text = gr.Textbox(label="Output", lines=3, interactive=False)
+                status_text = gr.Textbox(label="Status", lines=1, interactive=False)
+
+                with gr.Accordion("Advanced", open=False):
+                    collected_audio = gr.Audio(
+                        label="Full Audio",
+                        type="numpy",
+                        format="wav",
+                        interactive=False,
+                    )
+
+            with gr.Column():
+                with gr.Accordion("Settings Help"):
+                    gr.Markdown(
+                        "- `Preset Prompt` controls the response style.\n"
+                        "- `Preset Voice` controls the speaking tone.\n"
+                        "- `Custom Prompt` lets you define the response style in natural language (overrides `Preset Prompt`).\n"
+                        "- For best results, choose prompts and voices that match your language.\n"
+                        "- To apply new settings, end the current conversation and start a new one."
+                    )
+                preset_character_dropdown = gr.Dropdown(
+                    label="π Preset Prompt",
+                    choices=["[default]"],
+                )
+                preset_voice_dropdown = gr.Dropdown(
+                    label="🎤 Preset Voice",
+                    choices=["[default]"],
+                )
+                custom_character_prompt = gr.Textbox(
+                    label="🛠️ Custom Prompt",
+                    placeholder="For example: You are Xiaomi MiMo-Audio, a large language model trained by Xiaomi. You are chatting with a user over voice.",
+                    lines=2,
+                    interactive=True,
+                )
+
+        chat.stream(
+            fastrtc.ReplyOnPause(
+                response,
+                input_sample_rate=24000,
+                output_sample_rate=24000,
+                model_options=fastrtc.SileroVadOptions(
+                    threshold=0.7,
+                    min_silence_duration_ms=1000,
+                ),
+            ),
+            inputs=[
+                chat,
+                preset_character_dropdown,
+                preset_voice_dropdown,
+                custom_character_prompt,
+            ],
+            concurrency_limit=args.concurrency_limit,
+            outputs=[chat],
+        )
+        chat.on_additional_outputs(
+            lambda *args: args,
+            outputs=[output_text, status_text, collected_audio],
+            concurrency_limit=args.concurrency_limit,
+            show_progress="hidden",
+        )
+
+        demo.load(
+            load_initial_data,
+            inputs=[],
+            outputs=[title_markdown, preset_character_dropdown, preset_voice_dropdown],
+        )
+    demo.queue(
+        default_concurrency_limit=args.concurrency_limit,
+    )
+
+    demo.launch()
 
 
+if __name__ == "__main__":
+    main()
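For reference, `ConversationManager.chat` expects the backend's `/audio-chat` endpoint to stream server-sent-event-style lines: each payload is `data: ` followed by a `ChatResponseItem` JSON object, terminated by `data: [DONE]`. A minimal sketch of that framing, mirroring the client's parse loop (the sample payloads are made up):

```python
from api_schema import ChatResponseItem

# Hypothetical sample of what the backend emits, line by line.
stream = [
    ": keep-alive comment, skipped because it lacks the 'data: ' prefix",
    'data: {"text_chunk": "Hello"}',
    'data: {"end_of_transcription": true}',
    'data: {"end_of_stream": true, "stop_reason": "stop_token"}',
    "data: [DONE]",
]

for line in stream:
    if not line.startswith("data: "):
        continue                      # same filter app.py applies
    line = line.removeprefix("data: ")
    if line.strip() == "[DONE]":
        break                         # server signals end of stream
    item = ChatResponseItem.model_validate_json(line)
    print(item.text_chunk, item.stop_reason)
```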
requirements.txt ADDED
@@ -0,0 +1,5 @@
+fastapi==0.116.1
+pydantic==2.11.7
+fastrtc[vad]==0.0.33
+gradio==5.35.0
+httpx==0.28.1
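Finally, note that `app.py` reads three environment variables at import time: `URL_PREFIX` and `NUM_SERVER` define the backend pool (`int(os.getenv("NUM_SERVER"))` raises `TypeError` if `NUM_SERVER` is unset), and `HF_TOKEN` is used for TURN credentials and backend authorization. A minimal sketch with placeholder values:

```python
import os

# All values below are placeholders, not real endpoints or tokens.
os.environ["URL_PREFIX"] = "https://example-backend-"  # pool URLs become {URL_PREFIX}{i}.hf.space
os.environ["NUM_SERVER"] = "2"                         # required: int(None) would raise TypeError
os.environ["HF_TOKEN"] = "hf_xxx"                      # used for TURN + Authorization headers

# With these set, `python app.py` builds the pool:
#   ["https://example-backend-1.hf.space", "https://example-backend-2.hf.space"]
```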