from abc import ABC from io import BytesIO from typing import Literal import numpy as np from pydantic import BaseModel, ConfigDict class AbortController(ABC): def is_alive(self) -> bool: raise NotImplementedError class NeverAbortedController(AbortController): def is_alive(self) -> bool: return True def is_none_or_alive(abort_controller: AbortController | None) -> bool: return abort_controller is None or abort_controller.is_alive() class ModelNameResponse(BaseModel): model_name: str class TokenizedMessage(BaseModel): role: Literal["user", "assistant"] content: list[list[int]] """[audio_channels+1, time_steps]""" def time_steps(self) -> int: return len(self.content[0]) def append(self, chunk: list[list[int]]): assert len(chunk) == len(self.content), "Incompatible chunk length" assert all(len(c) == len(chunk[0]) for c in chunk), "Incompatible chunk shape" for content_channel, chunk_channel in zip(self.content, chunk): content_channel.extend(chunk_channel) class TokenizedConversation(BaseModel): messages: list[TokenizedMessage] def time_steps(self) -> int: return sum(msg.time_steps() for msg in self.messages) def latest_messages(self, max_time_steps: int) -> "list[TokenizedMessage]": sum_time_steps = 0 selected_messages: list[TokenizedMessage] = [] for msg in reversed(self.messages): cur_time_steps = msg.time_steps() if sum_time_steps + cur_time_steps > max_time_steps: break sum_time_steps += cur_time_steps selected_messages.append(msg) return list(reversed(selected_messages)) class ChatAudioBytes(BaseModel): model_config = ConfigDict(ser_json_bytes="base64", val_json_bytes="base64") sample_rate: int audio_data: bytes """ shape = (channels, samples) or (samples,); dtype = int16 or float32 """ @classmethod def from_audio(cls, audio: tuple[int, np.ndarray]) -> "ChatAudioBytes": buf = BytesIO() np.save(buf, audio[1]) return ChatAudioBytes(sample_rate=audio[0], audio_data=buf.getvalue()) def to_audio(self) -> tuple[int, np.ndarray]: buf = BytesIO(self.audio_data) audio_np = np.load(buf) return self.sample_rate, audio_np class ChatResponseItem(BaseModel): tokenized_input: TokenizedMessage | None = None token_chunk: list[list[int]] | None = None """[audio_channels+1, time_steps]""" text_chunk: str | None = None audio_chunk: ChatAudioBytes | None = None end_of_stream: bool | None = None """Represent Special token <|eostm|>""" end_of_transcription: bool | None = None """Represent Special token <|eot|> (not <|endoftext|>)""" stop_reason: str | None = None """The reason why the generation is stopped, e.g., max_new_tokens, max_length, stop_token, aborted""" class AssistantStyle(BaseModel): preset_character: str | None = None custom_character_prompt: str | None = None preset_voice: str | None = None custom_voice: ChatAudioBytes | None = None class SamplerConfig(BaseModel): """ Sampling configuration for text/audio generation. - If some fields are not set, their effects are disabled. - If the entire config is not set (e.g., `global_sampler_config=None`), all fields are automatically determined. - Use `temperature=0.0`/`top_k=1`/`top_p=0.0` instead of `do_sample=False` to disable sampling. """ temperature: float | None = None top_k: int | None = None top_p: float | None = None def normalized(self) -> tuple[float, int, float]: """ Returns: A tuple (temperature, top_k, top_p) with normalized values. """ if ( (self.temperature is not None and self.temperature <= 0.0) or (self.top_k is not None and self.top_k <= 1) or (self.top_p is not None and self.top_p <= 0.0) ): return (1.0, 1, 1.0) def default_clip[T: int | float]( value: T | None, default_value: T, min_value: T, max_value: T ) -> T: if value is None: return default_value return max(min(value, max_value), min_value) temperature = default_clip(self.temperature, 1.0, 0.01, 2.0) top_k = default_clip(self.top_k, 1_000_000, 1, 1_000_000) top_p = default_clip(self.top_p, 1.0, 0.01, 1.0) return (temperature, top_k, top_p) class ChatRequestBody(BaseModel): conversation: TokenizedConversation | None = None input_text: str | None = None input_audio: ChatAudioBytes | None = None assistant_style: AssistantStyle | None = None global_sampler_config: SamplerConfig | None = None local_sampler_config: SamplerConfig | None = None class PresetOptions(BaseModel): options: list[str]