from abc import ABC
from io import BytesIO
from typing import Literal

import numpy as np
from pydantic import BaseModel, ConfigDict
class AbortController(ABC):
    """Cooperative cancellation hook: generation should stop once is_alive() returns False."""

    def is_alive(self) -> bool:
        raise NotImplementedError


class NeverAbortedController(AbortController):
    """A controller that never aborts, for generation that cannot be cancelled."""

    def is_alive(self) -> bool:
        return True


def is_none_or_alive(abort_controller: AbortController | None) -> bool:
    return abort_controller is None or abort_controller.is_alive()
class ModelNameResponse(BaseModel):
    model_name: str
class TokenizedMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: list[list[int]]
    """[audio_channels+1, time_steps]"""

    def time_steps(self) -> int:
        return len(self.content[0])

    def append(self, chunk: list[list[int]]):
        assert len(chunk) == len(self.content), "Chunk channel count mismatch"
        assert all(len(c) == len(chunk[0]) for c in chunk), "Ragged chunk: channels differ in length"
        for content_channel, chunk_channel in zip(self.content, chunk):
            content_channel.extend(chunk_channel)
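# Sketch: appending a token chunk to a message. The token ids below are
# hypothetical; the chunk must have one row per channel, all equally long.
def _demo_append() -> None:
    msg = TokenizedMessage(role="user", content=[[1, 2], [10, 20]])
    msg.append([[3], [30]])
    assert msg.content == [[1, 2, 3], [10, 20, 30]]
    assert msg.time_steps() == 3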
class TokenizedConversation(BaseModel):
    messages: list[TokenizedMessage]

    def time_steps(self) -> int:
        return sum(msg.time_steps() for msg in self.messages)

    def latest_messages(self, max_time_steps: int) -> "list[TokenizedMessage]":
        """Return the newest messages that fit within max_time_steps, in
        chronological order. Messages are kept whole: the first message
        (scanning newest to oldest) that would exceed the budget stops the scan.
        """
        sum_time_steps = 0
        selected_messages: list[TokenizedMessage] = []
        for msg in reversed(self.messages):
            cur_time_steps = msg.time_steps()
            if sum_time_steps + cur_time_steps > max_time_steps:
                break
            sum_time_steps += cur_time_steps
            selected_messages.append(msg)
        return list(reversed(selected_messages))
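# Sketch: truncating a conversation to a context budget. The message sizes are
# hypothetical; latest_messages keeps whole messages, newest first, so an
# oversized older message (and everything before it) is dropped.
def _demo_latest_messages() -> None:
    old = TokenizedMessage(role="user", content=[[1] * 600])
    new = TokenizedMessage(role="assistant", content=[[2] * 300])
    conv = TokenizedConversation(messages=[old, new])
    assert conv.time_steps() == 900
    # With a 500-step budget only the newest message fits.
    assert conv.latest_messages(max_time_steps=500) == [new]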
class ChatAudioBytes(BaseModel):
    model_config = ConfigDict(ser_json_bytes="base64", val_json_bytes="base64")

    sample_rate: int
    audio_data: bytes
    """
    shape = (channels, samples) or (samples,);
    dtype = int16 or float32
    """

    @classmethod
    def from_audio(cls, audio: tuple[int, np.ndarray]) -> "ChatAudioBytes":
        buf = BytesIO()
        np.save(buf, audio[1])  # .npy format preserves dtype and shape
        return cls(sample_rate=audio[0], audio_data=buf.getvalue())

    def to_audio(self) -> tuple[int, np.ndarray]:
        buf = BytesIO(self.audio_data)
        audio_np = np.load(buf)
        return self.sample_rate, audio_np
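# Sketch: round-tripping audio through the JSON wire format. The sample rate
# and shape are hypothetical; np.save keeps dtype/shape, and model_config
# above makes pydantic base64-encode the bytes in JSON.
def _demo_audio_roundtrip() -> None:
    samples = np.zeros((2, 480), dtype=np.int16)  # (channels, samples)
    chat_audio = ChatAudioBytes.from_audio((24_000, samples))
    payload = chat_audio.model_dump_json()
    restored = ChatAudioBytes.model_validate_json(payload)
    sample_rate, audio_np = restored.to_audio()
    assert sample_rate == 24_000
    assert audio_np.shape == (2, 480) and audio_np.dtype == np.int16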
class ChatResponseItem(BaseModel):
    tokenized_input: TokenizedMessage | None = None
    token_chunk: list[list[int]] | None = None
    """[audio_channels+1, time_steps]"""
    text_chunk: str | None = None
    audio_chunk: ChatAudioBytes | None = None
    end_of_stream: bool | None = None
    """Represents the special token <|eostm|>"""
    end_of_transcription: bool | None = None
    """Represents the special token <|eot|> (not <|endoftext|>)"""
    stop_reason: str | None = None
    """Why generation stopped, e.g., max_new_tokens, max_length, stop_token, aborted"""
class AssistantStyle(BaseModel):
    preset_character: str | None = None
    custom_character_prompt: str | None = None
    preset_voice: str | None = None
    custom_voice: ChatAudioBytes | None = None
class SamplerConfig(BaseModel):
    """
    Sampling configuration for text/audio generation.
    - If some fields are not set, their effects are disabled.
    - If the entire config is not set (e.g., `global_sampler_config=None`), all fields are automatically determined.
    - Use `temperature=0.0`/`top_k=1`/`top_p=0.0` instead of `do_sample=False` to disable sampling.
    """

    temperature: float | None = None
    top_k: int | None = None
    top_p: float | None = None

    def normalized(self) -> tuple[float, int, float]:
        """
        Returns:
            A tuple (temperature, top_k, top_p) with normalized values.
        """
        # Any explicitly greedy setting disables sampling entirely.
        if (
            (self.temperature is not None and self.temperature <= 0.0)
            or (self.top_k is not None and self.top_k <= 1)
            or (self.top_p is not None and self.top_p <= 0.0)
        ):
            return (1.0, 1, 1.0)

        def default_clip[T: int | float](
            value: T | None, default_value: T, min_value: T, max_value: T
        ) -> T:
            if value is None:
                return default_value
            return max(min(value, max_value), min_value)

        temperature = default_clip(self.temperature, 1.0, 0.01, 2.0)
        top_k = default_clip(self.top_k, 1_000_000, 1, 1_000_000)
        top_p = default_clip(self.top_p, 1.0, 0.01, 1.0)
        return (temperature, top_k, top_p)
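# Sketch: how normalized() resolves unset and out-of-range fields. Any field
# at or below its greedy threshold forces fully greedy decoding.
def _demo_sampler_config() -> None:
    assert SamplerConfig().normalized() == (1.0, 1_000_000, 1.0)
    assert SamplerConfig(temperature=0.0).normalized() == (1.0, 1, 1.0)  # greedy
    assert SamplerConfig(temperature=5.0).normalized() == (2.0, 1_000_000, 1.0)  # clipped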
class ChatRequestBody(BaseModel):
    conversation: TokenizedConversation | None = None
    input_text: str | None = None
    input_audio: ChatAudioBytes | None = None
    assistant_style: AssistantStyle | None = None
    global_sampler_config: SamplerConfig | None = None
    local_sampler_config: SamplerConfig | None = None
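# Sketch: a minimal request body serialized for the wire. The field values are
# hypothetical; unset fields are omitted via exclude_none.
def _demo_request_body() -> str:
    body = ChatRequestBody(
        input_text="Hello!",
        global_sampler_config=SamplerConfig(temperature=0.7, top_p=0.9),
    )
    return body.model_dump_json(exclude_none=True)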
class PresetOptions(BaseModel):
    options: list[str]