# mimo_audio_chat/api_schema.py
from abc import ABC
from io import BytesIO
from typing import Literal
import numpy as np
from pydantic import BaseModel, ConfigDict


class AbortController(ABC):
    def is_alive(self) -> bool:
        raise NotImplementedError


class NeverAbortedController(AbortController):
    def is_alive(self) -> bool:
        return True


def is_none_or_alive(abort_controller: AbortController | None) -> bool:
    return abort_controller is None or abort_controller.is_alive()
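

# A minimal sketch (not part of the original module) of a concrete controller:
# a request handler could flip a threading.Event when the client disconnects,
# and generation loops would poll `is_alive()` between chunks. The `abort`
# method here is a hypothetical addition for illustration.
#
#   import threading
#
#   class EventAbortController(AbortController):
#       def __init__(self) -> None:
#           self._aborted = threading.Event()
#
#       def abort(self) -> None:
#           self._aborted.set()
#
#       def is_alive(self) -> bool:
#           return not self._aborted.is_set()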


class ModelNameResponse(BaseModel):
    model_name: str


class TokenizedMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: list[list[int]]
    """[audio_channels+1, time_steps]"""

    def time_steps(self) -> int:
        return len(self.content[0])

    def append(self, chunk: list[list[int]]):
        assert len(chunk) == len(self.content), "Incompatible chunk length"
        assert all(len(c) == len(chunk[0]) for c in chunk), "Incompatible chunk shape"
        for content_channel, chunk_channel in zip(self.content, chunk):
            content_channel.extend(chunk_channel)
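

# Illustrative sketch: `content` holds one row per channel (the
# `audio_channels+1` axis), all rows the same length (`time_steps`), and
# `append` extends every channel in lockstep. Token values below are arbitrary.
#
#   msg = TokenizedMessage(role="user", content=[[1, 2], [10, 11], [20, 21]])
#   msg.append([[3], [12], [22]])   # one new time step across all 3 channels
#   assert msg.time_steps() == 3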


class TokenizedConversation(BaseModel):
    messages: list[TokenizedMessage]

    def time_steps(self) -> int:
        return sum(msg.time_steps() for msg in self.messages)

    def latest_messages(self, max_time_steps: int) -> "list[TokenizedMessage]":
        sum_time_steps = 0
        selected_messages: list[TokenizedMessage] = []
        for msg in reversed(self.messages):
            cur_time_steps = msg.time_steps()
            if sum_time_steps + cur_time_steps > max_time_steps:
                break
            sum_time_steps += cur_time_steps
            selected_messages.append(msg)
        return list(reversed(selected_messages))
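

# Illustrative sketch: `latest_messages` walks the history backwards and keeps
# the most recent messages whose combined time_steps fit the budget, so older
# context is dropped first. Single-channel (text-only) content for brevity.
#
#   a = TokenizedMessage(role="user", content=[[1, 2, 3]])      # 3 steps
#   b = TokenizedMessage(role="assistant", content=[[4, 5]])    # 2 steps
#   conv = TokenizedConversation(messages=[a, b])
#   assert conv.latest_messages(max_time_steps=2) == [b]
#   assert conv.latest_messages(max_time_steps=5) == [a, b]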


class ChatAudioBytes(BaseModel):
    model_config = ConfigDict(ser_json_bytes="base64", val_json_bytes="base64")

    sample_rate: int
    audio_data: bytes
    """
    shape = (channels, samples) or (samples,);
    dtype = int16 or float32
    """

    @classmethod
    def from_audio(cls, audio: tuple[int, np.ndarray]) -> "ChatAudioBytes":
        buf = BytesIO()
        np.save(buf, audio[1])
        return ChatAudioBytes(sample_rate=audio[0], audio_data=buf.getvalue())

    def to_audio(self) -> tuple[int, np.ndarray]:
        buf = BytesIO(self.audio_data)
        audio_np = np.load(buf)
        return self.sample_rate, audio_np
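

# Illustrative sketch: the payload is a raw `.npy` blob (np.save/np.load), so a
# round trip preserves dtype and shape, and `model_dump_json` emits it
# base64-encoded thanks to the ConfigDict above. The sample rate is arbitrary.
#
#   waveform = np.zeros((2, 480), dtype=np.int16)        # (channels, samples)
#   chat_audio = ChatAudioBytes.from_audio((24000, waveform))
#   sr, restored = chat_audio.to_audio()
#   assert sr == 24000 and restored.dtype == np.int16 and restored.shape == (2, 480)
#   payload = chat_audio.model_dump_json()               # audio_data as base64 text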


class ChatResponseItem(BaseModel):
    tokenized_input: TokenizedMessage | None = None
    token_chunk: list[list[int]] | None = None
    """[audio_channels+1, time_steps]"""
    text_chunk: str | None = None
    audio_chunk: ChatAudioBytes | None = None
    end_of_stream: bool | None = None
    """Represents the special token <|eostm|>."""
    end_of_transcription: bool | None = None
    """Represents the special token <|eot|> (not <|endoftext|>)."""
    stop_reason: str | None = None
    """The reason generation stopped, e.g. max_new_tokens, max_length, stop_token, aborted."""


class AssistantStyle(BaseModel):
    preset_character: str | None = None
    custom_character_prompt: str | None = None
    preset_voice: str | None = None
    custom_voice: ChatAudioBytes | None = None


class SamplerConfig(BaseModel):
    """
    Sampling configuration for text/audio generation.

    - If some fields are not set, their effects are disabled.
    - If the entire config is not set (e.g. `global_sampler_config=None`), all fields are determined automatically.
    - Use `temperature=0.0` / `top_k=1` / `top_p=0.0` instead of `do_sample=False` to disable sampling.
    """

    temperature: float | None = None
    top_k: int | None = None
    top_p: float | None = None

    def normalized(self) -> tuple[float, int, float]:
        """
        Returns:
            A tuple (temperature, top_k, top_p) with normalized values.
        """
        # Any explicit "disable sampling" value collapses the whole config to greedy decoding.
        if (
            (self.temperature is not None and self.temperature <= 0.0)
            or (self.top_k is not None and self.top_k <= 1)
            or (self.top_p is not None and self.top_p <= 0.0)
        ):
            return (1.0, 1, 1.0)

        # Unset fields fall back to permissive defaults; set fields are clipped into range.
        def default_clip[T: int | float](
            value: T | None, default_value: T, min_value: T, max_value: T
        ) -> T:
            if value is None:
                return default_value
            return max(min(value, max_value), min_value)

        temperature = default_clip(self.temperature, 1.0, 0.01, 2.0)
        top_k = default_clip(self.top_k, 1_000_000, 1, 1_000_000)
        top_p = default_clip(self.top_p, 1.0, 0.01, 1.0)
        return (temperature, top_k, top_p)
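

# Illustrative sketch of `normalized()` behavior under the rules above:
#
#   assert SamplerConfig(temperature=0.0).normalized() == (1.0, 1, 1.0)             # greedy
#   assert SamplerConfig(top_k=50).normalized() == (1.0, 50, 1.0)                   # defaults fill in
#   assert SamplerConfig(temperature=5.0).normalized() == (2.0, 1_000_000, 1.0)     # clipped to 2.0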


class ChatRequestBody(BaseModel):
    conversation: TokenizedConversation | None = None
    input_text: str | None = None
    input_audio: ChatAudioBytes | None = None
    assistant_style: AssistantStyle | None = None
    global_sampler_config: SamplerConfig | None = None
    local_sampler_config: SamplerConfig | None = None
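

# Illustrative sketch of building a request body; which field combinations the
# server accepts (text vs. audio input, both, etc.) and the preset name are
# assumptions here, not defined by this module.
#
#   body = ChatRequestBody(
#       input_text="Hello!",
#       assistant_style=AssistantStyle(preset_character="default"),  # hypothetical preset name
#       global_sampler_config=SamplerConfig(temperature=0.7, top_p=0.95),
#   )
#   request_json = body.model_dump_json(exclude_none=True)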


class PresetOptions(BaseModel):
    options: list[str]