diff --git a/fish_speech/__pycache__/conversation.cpython-310.pyc b/fish_speech/__pycache__/conversation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4dc1336106c5d496e7a1c091e609089eb30d096 Binary files /dev/null and b/fish_speech/__pycache__/conversation.cpython-310.pyc differ diff --git a/fish_speech/__pycache__/scheduler.cpython-310.pyc b/fish_speech/__pycache__/scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ce90919af88b3c612722a85c3799f2cc4d58d76 Binary files /dev/null and b/fish_speech/__pycache__/scheduler.cpython-310.pyc differ diff --git a/fish_speech/callbacks/__init__.py b/fish_speech/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbcf3f33656d180ca87cd14a21ede1544e5a61a3 --- /dev/null +++ b/fish_speech/callbacks/__init__.py @@ -0,0 +1,3 @@ +from .grad_norm import GradNormMonitor + +__all__ = ["GradNormMonitor"] diff --git a/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc b/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..033bf77b0edc8dbe764c3e4386c005136b1ee50c Binary files /dev/null and b/fish_speech/callbacks/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc b/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2058510bf280afba7b92dc027d4629aa030b72fc Binary files /dev/null and b/fish_speech/callbacks/__pycache__/grad_norm.cpython-310.pyc differ diff --git a/fish_speech/callbacks/grad_norm.py b/fish_speech/callbacks/grad_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc95ef2a3723323b2d976001ed1e3c79c00b21a --- /dev/null +++ b/fish_speech/callbacks/grad_norm.py @@ -0,0 +1,113 @@ +from typing import Optional, Union + +import lightning.pytorch as pl +import torch +from lightning import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback +from torch import Tensor, nn +from torch.utils._foreach_utils import ( + _group_tensors_by_device_and_dtype, + _has_foreach_support, +) + + +@torch.no_grad() +def grad_norm( + parameters: Union[Tensor, list[Tensor]], + norm_type: float = 2.0, +) -> float: + """ + Returns the norm of the gradients of the given parameters. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + norm_type (float): type of the used p-norm. + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ # noqa: E501 + + if isinstance(parameters, Tensor): + parameters = [parameters] + + grads = [p.grad for p in parameters if p.grad is not None] + if len(grads) == 0: + return None + + first_device = grads[0].device + grouped_grads: dict[ + tuple[torch.device, torch.dtype], list[list[Tensor]] + ] = _group_tensors_by_device_and_dtype( + [[g.detach() for g in grads]] + ) # type: ignore[assignment] + + norms = [] + for (device, _), ([grads], _) in grouped_grads.items(): + if _has_foreach_support(grads, device=device): + norms.extend(torch._foreach_norm(grads, norm_type)) + else: + norms.extend([torch.norm(g, norm_type) for g in grads]) + + return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) + + +class GradNormMonitor(Callback): + """ + Callback that computes the gradient norm of the model parameters. 
+ """ + + def __init__( + self, + norm_type: float = 2.0, + logging_interval: str = "step", + sub_module: Optional[Union[str, list[str]]] = None, + ) -> None: + """ + Args: + norm_type (float): type of the used p-norm. + logging_interval (str): "step" or "epoch". + """ + super().__init__() + + self.norm_type = norm_type + self.logging_interval = logging_interval + self.sub_module = sub_module + + def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None: + """ + Computes the gradient norm of the model parameters and logs it to the logger. + + Args: + trainer (Trainer): The trainer object + model (LightningModule): The current lightningModule + """ + + lightning_model = model + + if self.sub_module is None: + return self.log_sub_module_grad_norm(lightning_model, model, "") + + sub_modules = self.sub_module + if isinstance(sub_modules, str): + sub_modules = [sub_modules] + + for sub_module in sub_modules: + self.log_sub_module_grad_norm( + lightning_model, getattr(model, sub_module), f"/{sub_module}" + ) + + def log_sub_module_grad_norm( + self, lightning_model: LightningModule, model: nn.Module, path: str + ) -> None: + grad_norm_val = grad_norm(model.parameters(), self.norm_type) + if grad_norm_val is None: + return + + on_step = self.logging_interval == "step" + lightning_model.log( + f"train{path}/grad_norm", + grad_norm_val, + on_step=on_step, + on_epoch=not on_step, + ) diff --git a/fish_speech/configs/base.yaml b/fish_speech/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..99e6dab54d3f57bce4f6d29a9129a19a523cad75 --- /dev/null +++ b/fish_speech/configs/base.yaml @@ -0,0 +1,87 @@ +# Base configuration for training a model +paths: + run_dir: results/${project} + ckpt_dir: ${paths.run_dir}/checkpoints + +hydra: + run: + dir: ${paths.run_dir} + +# Lightning Trainer +trainer: + _target_: lightning.pytorch.trainer.Trainer + + default_root_dir: ${paths.run_dir} + accelerator: gpu + num_nodes: 1 + devices: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + process_group_backend: nccl # This should be override when training on windows + + precision: bf16-mixed + + # disable validation by epoch end + check_val_every_n_epoch: null + val_check_interval: 5000 + max_steps: 100_000 + + # Use torch.backends.cudnn.benchmark to speed up training + benchmark: true + +# Callbacks +callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${paths.ckpt_dir} + filename: "step_{step:09d}" + save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 5 # save 5 latest checkpoints + monitor: step # use step to monitor checkpoints + mode: max # save the latest checkpoint with the highest global_step + every_n_epochs: null # don't save checkpoints by epoch end + every_n_train_steps: 5000 # save checkpoints every 5000 steps + auto_insert_metric_name: false + + model_summary: + _target_: lightning.pytorch.callbacks.ModelSummary + max_depth: 2 # the maximum depth of layer nesting that the summary will include + + learning_rate_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: step + log_momentum: false + + grad_norm_monitor: + _target_: fish_speech.callbacks.GradNormMonitor + norm_type: 2 + logging_interval: step + +# Logger +logger: + tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.run_dir}/tensorboard/" + name: null + log_graph: false + 
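The `grad_norm()` helper and `GradNormMonitor` callback above compute the total p-norm of all parameter gradients, which the `grad_norm_monitor` entry in `base.yaml` wires into training. A minimal sketch of the quantity that gets logged, run outside Lightning on a toy module (the module and data here are made up purely for illustration, not part of this diff):

```python
import torch
from torch import nn

from fish_speech.callbacks.grad_norm import grad_norm

# Toy module and loss, purely for illustration.
model = nn.Linear(16, 4)
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()

# Total L2 norm of all parameter gradients, viewed as a single vector --
# the same value GradNormMonitor computes in on_after_backward() and logs
# as train/grad_norm (or train/<sub_module>/grad_norm).
total_norm = grad_norm(list(model.parameters()), norm_type=2.0)
print(f"grad norm: {total_norm:.4f}")
```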
default_hp_metric: true + prefix: "" + + # wandb: + # _target_: lightning.pytorch.loggers.wandb.WandbLogger + # # name: "" # name of the run (normally generated by wandb) + # save_dir: "${paths.run_dir}" + # offline: False + # id: null # pass correct id to resume experiment! + # anonymous: null # enable anonymous logging + # project: "fish-speech" + # log_model: False # upload lightning ckpts + # prefix: "" # a string to put at the beginning of metric keys + # # entity: "" # set to name of your wandb team + # group: "" + # tags: ["vq", "hq", "finetune"] + # job_type: "" + +# Loop +train: true +test: false diff --git a/fish_speech/configs/firefly_gan_vq.yaml b/fish_speech/configs/firefly_gan_vq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..10aa8d4a522f0859ed8f541f5d48672d84b39c8f --- /dev/null +++ b/fish_speech/configs/firefly_gan_vq.yaml @@ -0,0 +1,33 @@ +_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture +spec_transform: + _target_: fish_speech.utils.spectrogram.LogMelSpectrogram + sample_rate: 44100 + n_mels: 160 + n_fft: 2048 + hop_length: 512 + win_length: 2048 +backbone: + _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder + input_channels: 160 + depths: [3, 3, 9, 3] + dims: [128, 256, 384, 512] + drop_path_rate: 0.2 + kernel_size: 7 +head: + _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator + hop_length: 512 + upsample_rates: [8, 8, 2, 2, 2] # aka. strides + upsample_kernel_sizes: [16, 16, 4, 4, 4] + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + num_mels: 512 + upsample_initial_channel: 512 + pre_conv_kernel_size: 13 + post_conv_kernel_size: 13 +quantizer: + _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize + input_dim: 512 + n_groups: 8 + n_codebooks: 1 + levels: [8, 5, 5, 5] + downsample_factor: [2, 2] diff --git a/fish_speech/configs/lora/r_8_alpha_16.yaml b/fish_speech/configs/lora/r_8_alpha_16.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aecc4d9766a18fe31c55941e01b1f590c95e77c9 --- /dev/null +++ b/fish_speech/configs/lora/r_8_alpha_16.yaml @@ -0,0 +1,4 @@ +_target_: fish_speech.models.text2semantic.lora.LoraConfig +r: 8 +lora_alpha: 16 +lora_dropout: 0.01 diff --git a/fish_speech/configs/text2semantic_finetune.yaml b/fish_speech/configs/text2semantic_finetune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4c1993023099e122fc9e004bda55ec075ed5e1b --- /dev/null +++ b/fish_speech/configs/text2semantic_finetune.yaml @@ -0,0 +1,83 @@ +defaults: + - base + - _self_ + +project: text2semantic_finetune_dual_ar +max_length: 4096 +pretrained_ckpt_path: checkpoints/fish-speech-1.4 + +# Lightning Trainer +trainer: + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: "norm" + max_steps: 1000 + precision: bf16-true + limit_val_batches: 10 + val_check_interval: 100 + +# Dataset Configuration +tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: ${pretrained_ckpt_path} + +# Dataset Configuration +train_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: ${max_length} + use_speaker: false + interactive_prob: 0.7 + +val_dataset: + _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset + proto_files: + - data/protos + tokenizer: ${tokenizer} + causal: true + max_length: 
${max_length} + use_speaker: false + interactive_prob: 0.7 + +data: + _target_: fish_speech.datasets.semantic.SemanticDataModule + train_dataset: ${train_dataset} + val_dataset: ${val_dataset} + num_workers: 4 + batch_size: 8 + tokenizer: ${tokenizer} + max_length: ${max_length} + +# Model Configuration +model: + _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic + model: + _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained + path: ${pretrained_ckpt_path} + load_weights: true + max_length: ${max_length} + lora_config: null + + optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 1e-4 + weight_decay: 0 + betas: [0.9, 0.95] + eps: 1e-5 + + lr_scheduler: + _target_: torch.optim.lr_scheduler.LambdaLR + _partial_: true + lr_lambda: + _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda + _partial_: true + num_warmup_steps: 10 + +# Callbacks +callbacks: + model_checkpoint: + every_n_train_steps: ${trainer.val_check_interval} diff --git a/fish_speech/conversation.py b/fish_speech/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ca0ef9181754eda7e6b49e01abeafbe07fb00f --- /dev/null +++ b/fish_speech/conversation.py @@ -0,0 +1,2 @@ +SEMANTIC_TOKEN = "<|semantic|>" +CODEBOOK_PAD_TOKEN_ID = 0 diff --git a/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc b/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca763c1c12b41234a939ddbe343575b67a79bb92 Binary files /dev/null and b/fish_speech/datasets/__pycache__/semantic.cpython-310.pyc differ diff --git a/fish_speech/datasets/concat_repeat.py b/fish_speech/datasets/concat_repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa596b95a572ee15c5570cbdb792c9a78e62dfa --- /dev/null +++ b/fish_speech/datasets/concat_repeat.py @@ -0,0 +1,53 @@ +import bisect +import random +from typing import Iterable + +from torch.utils.data import Dataset, IterableDataset + + +class ConcatRepeatDataset(Dataset): + datasets: list[Dataset] + cumulative_sizes: list[int] + repeats: list[int] + + @staticmethod + def cumsum(sequence, repeats): + r, s = [], 0 + for dataset, repeat in zip(sequence, repeats): + l = len(dataset) * repeat + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset], repeats: list[int]): + super().__init__() + + self.datasets = list(datasets) + self.repeats = repeats + + assert len(self.datasets) > 0, "datasets should not be an empty iterable" + assert len(self.datasets) == len( + repeats + ), "datasets and repeats should have the same length" + + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatRepeatDataset does not support IterableDataset" + + self.cumulative_sizes = self.cumsum(self.datasets, self.repeats) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + + dataset = self.datasets[dataset_idx] + + return dataset[sample_idx % len(dataset)] diff --git a/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b7bb23609e78b12b6e608581f1e8d764bd9db3a Binary files /dev/null and 
b/fish_speech/datasets/protos/__pycache__/text_data_pb2.cpython-310.pyc differ diff --git a/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc b/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e22635c991e5d669704c3cf95dd528a39b8b822 Binary files /dev/null and b/fish_speech/datasets/protos/__pycache__/text_data_stream.cpython-310.pyc differ diff --git a/fish_speech/datasets/protos/text-data.proto b/fish_speech/datasets/protos/text-data.proto new file mode 100644 index 0000000000000000000000000000000000000000..5eb26d94aa3be1e21066f2bf38c90d54e85a8379 --- /dev/null +++ b/fish_speech/datasets/protos/text-data.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package text_data; + +message Semantics { + repeated uint32 values = 1; +} + +message Sentence { + repeated string texts = 1; + repeated Semantics semantics = 3; +} + +message TextData { + string source = 1; + string name = 2; + repeated Sentence sentences = 4; +} + +message SampledData { + string source = 1; + string name = 2; + repeated Sentence samples = 3; +} diff --git a/fish_speech/datasets/protos/text_data_pb2.py b/fish_speech/datasets/protos/text_data_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..bfce0e8be59fc51e68999ef137e1fd0e4adc0d7e --- /dev/null +++ b/fish_speech/datasets/protos/text_data_pb2.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: text-data.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals["_SEMANTICS"]._serialized_start = 30 + _globals["_SEMANTICS"]._serialized_end = 57 + _globals["_SENTENCE"]._serialized_start = 59 + _globals["_SENTENCE"]._serialized_end = 125 + _globals["_TEXTDATA"]._serialized_start = 127 + _globals["_TEXTDATA"]._serialized_end = 207 + _globals["_SAMPLEDDATA"]._serialized_start = 209 + _globals["_SAMPLEDDATA"]._serialized_end = 290 +# @@protoc_insertion_point(module_scope) diff --git a/fish_speech/datasets/protos/text_data_stream.py b/fish_speech/datasets/protos/text_data_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..ec3c25bcd764e8245de47dcdf9686d6adfb5a107 --- /dev/null +++ b/fish_speech/datasets/protos/text_data_stream.py @@ -0,0 +1,36 @@ +import struct + +from .text_data_pb2 import 
TextData + + +def read_pb_stream(f): + while True: + buf = f.read(4) + if len(buf) == 0: + break + size = struct.unpack("I", buf)[0] + buf = f.read(size) + text_data = TextData() + text_data.ParseFromString(buf) + yield text_data + + +def write_pb_stream(f, text_data): + buf = text_data.SerializeToString() + f.write(struct.pack("I", len(buf))) + f.write(buf) + + +def pack_pb_stream(text_data): + buf = text_data.SerializeToString() + return struct.pack("I", len(buf)) + buf + + +def split_pb_stream(f): + while True: + head = f.read(4) + if len(head) == 0: + break + size = struct.unpack("I", head)[0] + buf = f.read(size) + yield head + buf diff --git a/fish_speech/datasets/semantic.py b/fish_speech/datasets/semantic.py new file mode 100644 index 0000000000000000000000000000000000000000..3c64e01077ae253bdc4e4d9cd948f8fb50df7418 --- /dev/null +++ b/fish_speech/datasets/semantic.py @@ -0,0 +1,496 @@ +import random +from dataclasses import dataclass +from itertools import chain +from pathlib import Path +from random import Random +from typing import Optional, Union + +import numpy as np +import pyarrow.parquet as pq +import torch +import torch.nn.functional as F +from datasets.download.streaming_download_manager import xopen +from huggingface_hub import HfApi +from lightning import LightningDataModule +from torch.distributed import get_rank, get_world_size, is_initialized +from torch.utils.data import DataLoader, IterableDataset, get_worker_info +from transformers import AutoTokenizer + +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.datasets.protos.text_data_pb2 import SampledData +from fish_speech.datasets.protos.text_data_stream import read_pb_stream +from fish_speech.text.clean import clean_text +from fish_speech.utils import RankedLogger +from fish_speech.utils.braceexpand import braceexpand + +log = RankedLogger(__name__, rank_zero_only=True) + + +def split_by_rank_worker(files): + # We need to know the total number of devices + # to split the data properly + + total_devices = 1 + if is_initialized(): + total_devices = get_world_size() + + worker_info = get_worker_info() + if worker_info is not None: + total_devices *= worker_info.num_workers + + if len(files) < total_devices: + # Repeat the files N times to match the number of devices + files = files * (total_devices // len(files) + 1) + + # DDP + if is_initialized(): + files = files[get_rank() :: get_world_size()] + + # Split by worker + if worker_info is not None: + files = files[worker_info.id :: worker_info.num_workers] + + return files + + +class AutoTextSemanticInstructionDataset(IterableDataset): + """ + Auto Augment Dataset by Speaker + + 1. Random concatenate multiple sentences from the same speaker to form a longer sentence + 2. Automatically normalize the text + + For interactive mode, we use the following format (multiple sequences): + [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] + + For non-interactive mode, we use the following format (one long sequence): + [INST] text [/INST] ... 
+ """ + + def __init__( + self, + proto_files: list[str], + seed: int = 42, + interactive_prob: float = 0.5, + max_length: int = 1024, + tokenizer: AutoTokenizer = None, + use_speaker: bool | float = True, + causal: bool = True, + num_codebooks: Optional[int] = None, + skip_text_prob: float = 0.0, + ): + """ + Args: + proto_files: proto buf files if using local data + seed: random seed + interactive_prob: probability to use interactive mode + max_length: max length of the text + tokenizer: tokenizer + use_speaker: include speaker information in the prompt + causal: use causal sampling when using local data, disable will lead to random sampling + num_codebooks: number of codebooks, if None, it will be automatically detected + skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode + """ + + super().__init__() + + assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]" + + self.seed = seed + self.max_length = max_length + self.tokenizer = tokenizer + self.interactive_prob = interactive_prob + self.use_speaker = use_speaker + self.proto_files = proto_files + self.causal = causal + self.num_codebooks = num_codebooks + self.skip_text_prob = skip_text_prob + + self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>") + self.groups = None + + def init_mock_data_server(self): + if self.groups is not None: + return + + # Expand the proto files + expanded_proto_files = [] + for filename in self.proto_files: + for i in braceexpand(filename): + i = Path(i) + if i.is_file(): + expanded_proto_files.append(i) + elif i.is_dir(): + expanded_proto_files.extend(i.rglob("*.proto")) + expanded_proto_files.extend(i.rglob("*.protos")) + else: + raise ValueError(f"{i} is not a file or directory") + + expanded_proto_files = sorted(expanded_proto_files) + Random(self.seed).shuffle(expanded_proto_files) + + self.groups = [] + shard_proto_files = split_by_rank_worker(expanded_proto_files) + log.info( + f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files" + ) + + count = 0 + for filename in shard_proto_files: + with open(filename, "rb") as f: + for text_data in read_pb_stream(f): + self.groups.append(text_data) + count += 1 + + log.info(f"Read total {count} groups of data") + + # Shuffle the lines + Random(self.seed).shuffle(self.groups) + self.group_weights = [len(i.sentences) for i in self.groups] + + def __iter__(self): + while True: + yield self.augment() + + def tokenize_sentence(self, sentence: str): + sentence = clean_text(sentence) + tokens = self.tokenizer.encode( + f"{sentence}", + max_length=10**6, + add_special_tokens=False, + truncation=False, + ) + return sentence, len(tokens) + + def sample_data(self): + if self.groups is None: + self.init_mock_data_server() + + # Shuffle unique lines, estimate that each sample is at least 20 tokens + num_samples = self.max_length // 20 + + # choice group based on their number of samples + group = random.choices(self.groups, weights=self.group_weights, k=1)[0] + + if self.causal: + # Sample in order + if num_samples >= len(group.sentences): + samples = group.sentences + else: + begin = random.randint(0, len(group.sentences) - num_samples) + samples = group.sentences[begin : begin + num_samples] + else: + samples = random.choices( + group.sentences, k=min(num_samples, len(group.sentences)) + ) + + return SampledData( + source=group.source, + name=group.name, + samples=samples, + ) + + def augment(self): + final_text, final_semantic = [], [] + response = self.sample_data() + if 
len(response.samples) == 0: + # Invalid group + return None + + samples = list(response.samples) + idx = 0 + use_interactive = random.random() < self.interactive_prob + + if use_interactive is False: + # Random sample based on speaker using a truncated normal distribution + a = torch.tensor([0], dtype=torch.float32) + torch.nn.init.trunc_normal_( + a, + mean=self.max_length // 2, + std=self.max_length // 4, + a=10, + b=self.max_length, + ) + remaining_tokens = a.long().item() - 4 + else: + remaining_tokens = self.max_length + + # Use speaker + if isinstance(self.use_speaker, float): + use_speaker = random.random() < self.use_speaker + else: + use_speaker = self.use_speaker + + all_tokens, all_labels = [], [] + while remaining_tokens > 0 and len(samples) > 0: + sentence = samples.pop(0) + + text = random.choice(sentence.texts) + text, length = self.tokenize_sentence(text) + remaining_tokens -= length + len(sentence.semantics[0].values) + + if use_interactive is False: + final_text.append(text) + final_semantic.append(sentence.semantics) + else: + # For interactive mode, we only apply speaker for the first sentence + # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] + tokens, labels = self.pack_sentences( + sentences=[text], + semantics=[sentence.semantics], + speaker=response.name if use_speaker else None, + skip_text=random.random() < self.skip_text_prob, + ) + + all_tokens.append(tokens) + all_labels.append(labels) + + idx += 1 + + if use_interactive is False: + tokens, labels = self.pack_sentences( + final_text, + semantics=final_semantic, + speaker=response.name if use_speaker else None, + ) + all_tokens.append(tokens) + all_labels.append(labels) + + tokens = torch.cat(all_tokens, dim=1) + labels = torch.cat(all_labels, dim=1) + + # Verify that the length is correct + assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}" + + data = {"tokens": tokens, "labels": labels} + + return data + + def pack_sentences( + self, + sentences: list[str], + semantics: list, + speaker: Optional[str] = None, + skip_text: bool = False, + ): + if speaker is None: + speaker = "assistant" + + cated_sentences = " ".join(sentences) + if skip_text: + cated_sentences = "<|skip_text|>" + + final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>" + final_text = final_text + f"<|im_start|>{speaker}\n" + + encoded = self.tokenizer.encode( + final_text, + add_special_tokens=False, + truncation=False, + max_length=10**6, + ) + semantic_length = sum([len(i[0].values) for i in semantics]) + prompt_length = len(encoded) + num_codebooks = ( + len(semantics[0]) if self.num_codebooks is None else self.num_codebooks + ) + + # Pack the tokens and semantics (add and to semantic tokens) + tokens = ( + encoded + + [self.semantic_token_id] * semantic_length + + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"]) + ) + + # Codebook bos/padding: 0, eos: 1 + codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)] + for segment in semantics: + for book_idx, book in zip(range(num_codebooks), segment): + for j in book.values: + codes[book_idx].append(int(j) + 1) + + for book in codes: + book.extend([CODEBOOK_PAD_TOKEN_ID] * 1) + + tokens = [tokens] + codes + + tokens = torch.tensor(tokens, dtype=torch.long) + labels = tokens.clone() + + if skip_text: + # If text is not provided, the sentence is used for condition only, all labels are -100 + torch.fill_(labels, -100) + return tokens, labels + + # Mask out the tokens for semantic, predict semantic tokens only + # 
Since we don't mask out the input tokens, the language modeling still works + labels[1:, :prompt_length] = -100 + + tokens = tokens[:, :-1] + labels = labels[:, 1:] + + # Verify the padding is correct, and the last token is eos + assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all() + assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all() + + return tokens, labels + + +@dataclass +class TextDataCollator: + tokenizer: AutoTokenizer + max_length: int = 1024 + + def __call__(self, examples): + if "negative_tokens" in examples: + positive_examples = [] + negative_examples = [] + + for i in examples: + positive_examples.append( + { + "tokens": i["tokens"], + "labels": i["labels"], + } + ) + negative_examples.append( + { + "tokens": i["negative_tokens"], + "labels": i["negative_labels"], + } + ) + + examples = positive_examples + negative_examples + + return self.batchify(examples) + + def batchify(self, examples, tokens_key="tokens", labels_key="labels"): + tokens, attention_masks, labels = [], [], [] + + # Calculate the max length + max_tokens_length = 0 + for example in examples: + max_tokens_length = max(max_tokens_length, example[tokens_key].size(1)) + max_tokens_length = min(max_tokens_length, self.max_length) + + for example in examples: + _tokens = example[tokens_key][:, :max_tokens_length] + _labels = example[labels_key][:, :max_tokens_length] + _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool) + tokens_length = _tokens.size(1) + _attention_mask[:tokens_length] = False + + assert tokens_length == _labels.size( + 1 + ), f"{tokens_length} != {_labels.size(1)}" + + if tokens_length < max_tokens_length: + _tokens = F.pad( + _tokens, + (0, max_tokens_length - tokens_length), + value=self.tokenizer.eos_token_id, + ) + _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID + _labels = F.pad( + _labels, (0, max_tokens_length - _labels.size(1)), value=-100 + ) + + tokens.append(_tokens) + attention_masks.append(_attention_mask) + labels.append(_labels) + + tokens = torch.stack(tokens, dim=0) + attention_masks = torch.stack(attention_masks, dim=0) + labels = torch.stack(labels, dim=0) + + return { + "inputs": tokens, + "attention_masks": attention_masks, + "labels": labels, + } + + +class InterleaveDataset(IterableDataset): + def __init__( + self, + datasets: list[IterableDataset], + probabilities: list[float], + seed: int = 42, + ): + super().__init__() + + self.datasets = datasets + self.probabilities = probabilities + self.seed = seed + + def __iter__(self): + rng = np.random.default_rng(self.seed) + dataset_iterators = [iter(dataset) for dataset in self.datasets] + + while True: + # Random choice one + dataset_idx = rng.choice(len(self.datasets), p=self.probabilities) + dataset_iterator = dataset_iterators[dataset_idx] + + try: + yield next(dataset_iterator) + except StopIteration: + # Exhausted, create a new iterator + dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx]) + yield next(dataset_iterators[dataset_idx]) + + +class SemanticDataModule(LightningDataModule): + def __init__( + self, + train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset], + batch_size: int = 32, + tokenizer: AutoTokenizer = None, + max_length: int = 1024, + num_workers: int = 4, + ): + super().__init__() + + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.batch_size = batch_size + self.tokenizer = tokenizer + self.max_length = max_length + 
self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=TextDataCollator(self.tokenizer, self.max_length), + num_workers=self.num_workers, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + collate_fn=TextDataCollator(self.tokenizer, self.max_length), + num_workers=self.num_workers, + persistent_workers=True, + ) + + +if __name__ == "__main__": + from tqdm import tqdm + + ds = AutoTextSemanticInstructionDataset( + ["data/protos"], + tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"), + use_speaker=False, + interactive_prob=1.0, + skip_text_prob=0.5, + ) + + for i in ds: + print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False)) + # i["labels"][0][i["labels"][0] == -100] = 0 + # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False)) + break diff --git a/fish_speech/datasets/vqgan.py b/fish_speech/datasets/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a45583d22efb0feb9dc1e823bae1ef74534b299e --- /dev/null +++ b/fish_speech/datasets/vqgan.py @@ -0,0 +1,147 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import librosa +import numpy as np +import torch +from lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset + +from fish_speech.utils import RankedLogger + +logger = RankedLogger(__name__, rank_zero_only=False) + + +class VQGANDataset(Dataset): + def __init__( + self, + filelist: str, + sample_rate: int = 32000, + hop_length: int = 640, + slice_frames: Optional[int] = None, + ): + super().__init__() + + filelist = Path(filelist) + root = filelist.parent + + self.files = [ + root / line.strip() + for line in filelist.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + self.sample_rate = sample_rate + self.hop_length = hop_length + self.slice_frames = slice_frames + + def __len__(self): + return len(self.files) + + def get_item(self, idx): + file = self.files[idx] + + audio, _ = librosa.load(file, sr=self.sample_rate, mono=True) + + # Slice audio and features + if ( + self.slice_frames is not None + and audio.shape[0] > self.slice_frames * self.hop_length + ): + start = np.random.randint( + 0, audio.shape[0] - self.slice_frames * self.hop_length + ) + audio = audio[start : start + self.slice_frames * self.hop_length] + + if len(audio) == 0: + return None + + max_value = np.abs(audio).max() + if max_value > 1.0: + audio = audio / max_value + + return { + "audio": torch.from_numpy(audio), + } + + def __getitem__(self, idx): + try: + return self.get_item(idx) + except Exception as e: + import traceback + + traceback.print_exc() + logger.error(f"Error loading {self.files[idx]}: {e}") + return None + + +@dataclass +class VQGANCollator: + def __call__(self, batch): + batch = [x for x in batch if x is not None] + + audio_lengths = torch.tensor([len(x["audio"]) for x in batch]) + audio_maxlen = audio_lengths.max() + + # Rounds up to nearest multiple of 2 (audio_lengths) + audios = [] + for x in batch: + audios.append( + torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"]))) + ) + + return { + "audios": torch.stack(audios), + "audio_lengths": audio_lengths, + } + + +class VQGANDataModule(LightningDataModule): + def __init__( + self, + train_dataset: VQGANDataset, + val_dataset: VQGANDataset, + batch_size: int = 32, + num_workers: int = 4, + 
val_batch_size: Optional[int] = None, + ): + super().__init__() + + self.train_dataset = train_dataset + self.val_dataset = val_dataset + self.batch_size = batch_size + self.val_batch_size = val_batch_size or batch_size + self.num_workers = num_workers + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=VQGANCollator(), + num_workers=self.num_workers, + shuffle=True, + persistent_workers=True, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + collate_fn=VQGANCollator(), + num_workers=self.num_workers, + persistent_workers=True, + ) + + +if __name__ == "__main__": + dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt") + dataloader = DataLoader( + dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator() + ) + + for batch in dataloader: + print(batch["audios"].shape) + print(batch["features"].shape) + print(batch["audio_lengths"]) + print(batch["feature_lengths"]) + break diff --git a/fish_speech/i18n/README.md b/fish_speech/i18n/README.md new file mode 100644 index 0000000000000000000000000000000000000000..700902b09db20911ef1ad678cbdce5644b84aea2 --- /dev/null +++ b/fish_speech/i18n/README.md @@ -0,0 +1,27 @@ +## i18n Folder Attribution + +The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below: + +### fish_speech/i18n/core.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py) + +**Initial commit:** +add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35) + +**Initial author:** +[@L4Ph](https://github.com/L4Ph) + +### fish_speech/i18n/scan.py + +**Related code from RVC:** +[https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py) + +**Initial commit:** +File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058) + +**Initial author:** +[@towzeur](https://github.com/towzeur) + +We appreciate the contributions of the RVC project and its authors. 
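The `i18n` helper whose provenance is documented above (and whose implementation follows in `fish_speech/i18n/core.py`) resolves UI strings against the locale JSON files in this folder, falling back to the key itself when no translation exists. A minimal usage sketch (which locale is loaded depends on the `.locale` file or the system locale at import time):

```python
from fish_speech.i18n import i18n

# Keys are the English source strings used in the WebUI; the value comes from
# the locale JSON matching the detected language (e.g. ja_JP.json returns
# "トレーニング開始" for this key, en_US.json returns it unchanged).
print(i18n("Start Training"))

# Unknown keys are returned as-is rather than raising.
print(i18n("Some key that has no translation"))
```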
diff --git a/fish_speech/i18n/__init__.py b/fish_speech/i18n/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..981dbb3b3ecf28043ec9ff5757f947182821a246 --- /dev/null +++ b/fish_speech/i18n/__init__.py @@ -0,0 +1,3 @@ +from .core import i18n + +__all__ = ["i18n"] diff --git a/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba5a935b26a69595794d6840da906e6615c3a52f Binary files /dev/null and b/fish_speech/i18n/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/i18n/__pycache__/core.cpython-310.pyc b/fish_speech/i18n/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66d2787af00a38bc8ffebb84ed30565a71e94b01 Binary files /dev/null and b/fish_speech/i18n/__pycache__/core.cpython-310.pyc differ diff --git a/fish_speech/i18n/core.py b/fish_speech/i18n/core.py new file mode 100644 index 0000000000000000000000000000000000000000..9f793ec95669228f7f4e8f9a7a5fe38da85c74bd --- /dev/null +++ b/fish_speech/i18n/core.py @@ -0,0 +1,40 @@ +import json +import locale +from pathlib import Path + +I18N_FILE_PATH = Path(__file__).parent / "locale" +DEFAULT_LANGUAGE = "en_US" + + +def load_language_list(language): + with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f: + language_list = json.load(f) + + return language_list + + +class I18nAuto: + def __init__(self): + i18n_file = Path(".locale") + + if i18n_file.exists(): + with open(i18n_file, "r", encoding="utf-8") as f: + language = f.read().strip() + else: + # getlocale can't identify the system's language ((None, None)) + language = locale.getdefaultlocale()[0] + + if (I18N_FILE_PATH / f"{language}.json").exists() is False: + language = DEFAULT_LANGUAGE + + self.language = language + self.language_map = load_language_list(language) + + def __call__(self, key): + return self.language_map.get(key, key) + + def __repr__(self): + return "Use Language: " + self.language + + +i18n = I18nAuto() diff --git a/fish_speech/i18n/locale/en_US.json b/fish_speech/i18n/locale/en_US.json new file mode 100644 index 0000000000000000000000000000000000000000..6e280c236e9c79de2087ec33c7bf6f8e1a5296c4 --- /dev/null +++ b/fish_speech/i18n/locale/en_US.json @@ -0,0 +1,122 @@ +{ + "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Accumulate Gradient Batches", + "Add to Processing Area": "Add to Processing Area", + "Added path successfully!": "Added path successfully!", + "Advanced Config": "Advanced Config", + "Base LLAMA Model": "Base LLAMA Model", + "Batch Inference": "Batch Inference", + "Batch Size": "Batch Size", + "Changing with the Model Path": "Changing with the Model Path", + "Chinese": "Chinese", + "Compile Model": "Compile Model", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time", + "Copy": "Copy", + "Data Preprocessing": "Data Preprocessing", + "Data Preprocessing 
Path": "Data Preprocessing Path", + "Data Source": "Data Source", + "Decoder Model Config": "Decoder Model Config", + "Decoder Model Path": "Decoder Model Path", + "Disabled": "Disabled", + "Enable Reference Audio": "Enable Reference Audio", + "English": "English", + "Error Message": "Error Message", + "File Preprocessing": "File Preprocessing", + "Generate": "Generate", + "Generated Audio": "Generated Audio", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format", + "Infer interface is closed": "Infer interface is closed", + "Inference Configuration": "Inference Configuration", + "Inference Server Configuration": "Inference Server Configuration", + "Inference Server Error": "Inference Server Error", + "Inferring interface is launched at {}": "Inferring interface is launched at {}", + "Initial Learning Rate": "Initial Learning Rate", + "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription", + "Input Text": "Input Text", + "Invalid path: {}": "Invalid path: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU", + "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off", + "Japanese": "Japanese", + "LLAMA Configuration": "LLAMA Configuration", + "LLAMA Model Config": "LLAMA Model Config", + "LLAMA Model Path": "LLAMA Model Path", + "Labeling Device": "Labeling Device", + "LoRA Model to be merged": "LoRA Model to be merged", + "Maximum Audio Duration": "Maximum Audio Duration", + "Maximum Length per Sample": "Maximum Length per Sample", + "Maximum Training Steps": "Maximum Training Steps", + "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit", + "Merge": "Merge", + "Merge LoRA": "Merge LoRA", + "Merge successfully": "Merge successfully", + "Minimum Audio Duration": "Minimum Audio Duration", + "Model Output Path": "Model Output Path", + "Model Size": "Model Size", + "Move": "Move", + "Move files successfully": "Move files successfully", + "No audio generated, please check the input text.": "No audio generated, please check the input text.", + "No selected options": "No selected options", + "Number of Workers": "Number of Workers", + "Open Inference Server": "Open Inference Server", + "Open Labeler WebUI": "Open Labeler WebUI", + "Open Tensorboard": "Open Tensorboard", + "Opened labeler in browser": "Opened labeler in browser", + "Optional Label Language": "Optional Label Language", + "Optional online ver": "Optional online ver", + "Output Path": "Output Path", + "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path", + "Precision": "Precision", + "Probability of applying Speaker Condition": "Probability of applying Speaker Condition", + "Put your text here.": "Put your text here.", + "Reference Audio": "Reference Audio", + "Reference Text": "Reference Text", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.", + "Remove Selected Data": "Remove Selected Data", + "Removed path successfully!": "Removed path successfully!", + "Repetition Penalty": "Repetition Penalty", + "Save model every n steps": "Save model every n steps", + "Select LLAMA ckpt": "Select LLAMA 
ckpt", + "Select VITS ckpt": "Select VITS ckpt", + "Select VQGAN ckpt": "Select VQGAN ckpt", + "Select source file processing method": "Select source file processing method", + "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)", + "Selected: {}": "Selected: {}", + "Speaker": "Speaker", + "Speaker is identified by the folder name": "Speaker is identified by the folder name", + "Start Training": "Start Training", + "Streaming Audio": "Streaming Audio", + "Streaming Generate": "Streaming Generate", + "Tensorboard Host": "Tensorboard Host", + "Tensorboard Log Path": "Tensorboard Log Path", + "Tensorboard Port": "Tensorboard Port", + "Tensorboard interface is closed": "Tensorboard interface is closed", + "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}", + "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.", + "Training Configuration": "Training Configuration", + "Training Error": "Training Error", + "Training stopped": "Training stopped", + "Type name of the speaker": "Type name of the speaker", + "Type the path or select from the dropdown": "Type the path or select from the dropdown", + "Use LoRA": "Use LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model", + "Use filelist": "Use filelist", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G", + "VITS Configuration": "VITS Configuration", + "VQGAN Configuration": "VQGAN Configuration", + "Validation Batch Size": "Validation Batch Size", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.", + "WebUI Host": "WebUI Host", + "WebUI Port": "WebUI Port", + "Whisper Model": "Whisper Model", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU", + "latest": "latest", + "new": "new", + "Realtime Transform Text": "Realtime Transform Text", + "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)", + "Text Normalization": "Text Normalization" +} diff --git a/fish_speech/i18n/locale/es_ES.json b/fish_speech/i18n/locale/es_ES.json new file mode 100644 index 0000000000000000000000000000000000000000..3285341f6893fe3e2ccbee6490dd8c90ed21854e --- /dev/null +++ 
b/fish_speech/i18n/locale/es_ES.json @@ -0,0 +1,122 @@ +{ + "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular lotes de gradientes", + "Add to Processing Area": "Agregar al Área de Procesamiento", + "Added path successfully!": "¡Ruta agregada exitosamente!", + "Advanced Config": "Configuración Avanzada", + "Base LLAMA Model": "Modelo Base LLAMA", + "Batch Inference": "Inferencia por Lote", + "Batch Size": "Tamaño del Lote", + "Changing with the Model Path": "Cambiando con la Ruta del Modelo", + "Chinese": "Chino", + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío", + "Copy": "Copiar", + "Data Preprocessing": "Preprocesamiento de Datos", + "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos", + "Data Source": "Fuente de Datos", + "Decoder Model Config": "Configuración del modelo decodificador", + "Decoder Model Path": "Ruta del modelo decodificador", + "Disabled": "Desactivado", + "Enable Reference Audio": "Habilitar Audio de Referencia", + "English": "Inglés", + "Error Message": "Mensaje de Error", + "File Preprocessing": "Preprocesamiento de Archivos", + "Generate": "Generar", + "Generated Audio": "Audio Generado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab", + "Infer interface is closed": "La interfaz de inferencia está cerrada", + "Inference Configuration": "Configuración de Inferencia", + "Inference Server Configuration": "Configuración del Servidor de Inferencia", + "Inference Server Error": "Error del Servidor de Inferencia", + "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}", + "Initial Learning Rate": "Tasa de Aprendizaje Inicial", + "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Ruta inválida: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU", + "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado", + "Japanese": "Japonés", + "LLAMA Configuration": "Configuración de LLAMA", + "LLAMA Model Config": "Configuración del Modelo LLAMA", + "LLAMA Model Path": "Ruta del Modelo LLAMA", + "Labeling Device": "Dispositivo de Etiquetado", + "LoRA Model to be merged": "Modelo LoRA a fusionar", + "Maximum Audio Duration": "Duración máxima de audio", + "Maximum Length per Sample": "Longitud Máxima por Muestra", + "Maximum Training Steps": "Pasos Máximos de Entrenamiento", + "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite", + "Merge": "Fusionar", + "Merge LoRA": "Fusionar LoRA", + "Merge successfully": "Fusionado exitosamente", + 
"Minimum Audio Duration": "Duración mínima de audio", + "Model Output Path": "Ruta de Salida del Modelo", + "Model Size": "Tamaño del Modelo", + "Move": "Mover", + "Move files successfully": "Archivos movidos exitosamente", + "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.", + "No selected options": "No hay opciones seleccionadas", + "Number of Workers": "Número de Trabajadores", + "Open Inference Server": "Abrir Servidor de Inferencia", + "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "Se abrió el etiquetador en el navegador", + "Optional Label Language": "Idioma de Etiquetado Opcional", + "Optional online ver": "Ver en línea opcional", + "Output Path": "Ruta de Salida", + "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente", + "Precision": "Precisión", + "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante", + "Put your text here.": "Ponga su texto aquí.", + "Reference Audio": "Audio de Referencia", + "Reference Text": "Texto de Referencia", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.", + "Remove Selected Data": "Eliminar Datos Seleccionados", + "Removed path successfully!": "¡Ruta eliminada exitosamente!", + "Repetition Penalty": "Penalización por Repetición", + "Save model every n steps": "Guardar modelo cada n pasos", + "Select LLAMA ckpt": "Seleccionar punto de control LLAMA", + "Select VITS ckpt": "Seleccionar punto de control VITS", + "Select VQGAN ckpt": "Seleccionar punto de control VQGAN", + "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente", + "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)", + "Selected: {}": "Seleccionado: {}", + "Speaker": "Hablante", + "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta", + "Start Training": "Iniciar Entrenamiento", + "Streaming Audio": "transmisión de audio", + "Streaming Generate": "síntesis en flujo", + "Tensorboard Host": "Host de Tensorboard", + "Tensorboard Log Path": "Ruta de Registro de Tensorboard", + "Tensorboard Port": "Puerto de Tensorboard", + "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada", + "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}", + "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. 
Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.", + "Training Configuration": "Configuración de Entrenamiento", + "Training Error": "Error de Entrenamiento", + "Training stopped": "Entrenamiento detenido", + "Type name of the speaker": "Escriba el nombre del hablante", + "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo", + "Use filelist": "Usar lista de archivos", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G", + "VITS Configuration": "Configuración de VITS", + "VQGAN Configuration": "Configuración de VQGAN", + "Validation Batch Size": "Tamaño del Lote de Validación", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.", + "WebUI Host": "Host de WebUI", + "WebUI Port": "Puerto de WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+", + "latest": "más reciente", + "new": "nuevo", + "Realtime Transform Text": "Transformación de Texto en Tiempo Real", + "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)", + "Text Normalization": "Normalización de Texto" +} diff --git a/fish_speech/i18n/locale/ja_JP.json b/fish_speech/i18n/locale/ja_JP.json new file mode 100644 index 0000000000000000000000000000000000000000..d30bac7bcdf4f4c65b1f78b4dcf9d705c1d8eb39 --- /dev/null +++ b/fish_speech/i18n/locale/ja_JP.json @@ -0,0 +1,123 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。", + "Accumulate Gradient Batches": "勾配バッチの累積", + "Add to Processing Area": "処理エリアに追加", + "Added path successfully!": "パスの追加に成功しました!", + "Advanced Config": "詳細設定", + "Base LLAMA Model": "基本LLAMAモデル", + "Batch Inference": "バッチ推論", + "Batch Size": "バッチサイズ", + "Changing with the Model Path": "モデルのパスに伴って変化する", + "Chinese": "中国語", + "Compile Model": "モデルのコンパイル", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります", + "Copy": "コピー", + "Data Preprocessing": "データ前処理", + "Data Preprocessing Path": "データ前処理パス", + "Data Source": "データソース", + 
"Decoder Model Config": "デコーダーモデルの構成", + "Decoder Model Path": "デコーダーモデルのパス", + "Disabled": "無効", + "Enable Reference Audio": "リファレンスオーディオを有効にする", + "English": "英語", + "Error Message": "エラーメッセージ", + "File Preprocessing": "文書前处理", + "Generate": "生成", + "Generated Audio": "生成されたオーディオ", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています", + "Infer interface is closed": "推論インターフェースが閉じられています", + "Inference Configuration": "推論設定", + "Inference Server Configuration": "推論サーバー設定", + "Inference Server Error": "推論サーバーエラー", + "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました", + "Initial Learning Rate": "初期学習率", + "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス", + "Input Text": "入力テキスト", + "Invalid path: {}": "無効なパス: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください", + "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します", + "Japanese": "日本語", + "LLAMA Configuration": "LLAMA設定", + "LLAMA Model Config": "LLAMAモデル設定", + "LLAMA Model Path": "LLAMAモデルパス", + "Labeling Device": "ラベリングデバイス", + "LoRA Model to be merged": "マージするLoRAモデル", + "Maximum Audio Duration": "最大オーディオの長さ", + "Maximum Length per Sample": "サンプルあたりの最大長", + "Maximum Training Steps": "最大トレーニングステップ数", + "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します", + "Merge": "マージ", + "Merge LoRA": "LoRAのマージ", + "Merge successfully": "マージに成功しました", + "Minimum Audio Duration": "最小オーディオの長さ", + "Model Output Path": "モデル出力パス", + "Model Size": "モデルサイズ", + "Move": "移動", + "Move files successfully": "ファイルの移動に成功しました", + "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。", + "No selected options": "選択されたオプションはありません", + "Number of Workers": "ワーカー数", + "Open Inference Server": "推論サーバーを開く", + "Open Labeler WebUI": "ラベラーWebUIを開く", + "Open Tensorboard": "Tensorboardを開く", + "Opened labeler in browser": "ブラウザでラベラーを開きました", + "Optional Label Language": "オプションのラベル言語", + "Optional online ver": "オプションのオンラインバージョン", + "Output Path": "出力パス", + "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください", + "Precision": "精度", + "Probability of applying Speaker Condition": "話者条件を適用する確率", + "Put your text here.": "ここにテキストを入力してください。", + "Reference Audio": "リファレンスオーディオ", + "Reference Text": "リファレンステキスト", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。", + "Remove Selected Data": "選択したデータを削除", + "Removed path successfully!": "パスの削除に成功しました!", + "Repetition Penalty": "反復ペナルティ", + "Save model every n steps": "nステップごとにモデルを保存", + "Select LLAMA ckpt": " LLAMA チェックポイントを選択", + "Select VITS ckpt": "VITS チェックポイントを選択", + "Select VQGAN ckpt": "VQGAN チェックポイントを選択", + "Select source file processing method": "ソースファイルの処理方法を選択", + "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください", + "Selected: {}": "選択済み: {}", + "Speaker": "話者", + "Speaker is identified by the folder name": "話者はフォルダ名で識別されます", + "Start Training": "トレーニング開始", + "Streaming Audio": "ストリーミングオーディオ", + "Streaming Generate": "ストリーミング合成", + "Tensorboard Host": "Tensorboardホスト", + "Tensorboard Log Path": "Tensorboardログパス", + "Tensorboard Port": "Tensorboardポート", + "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています", + "Tensorboard interface is launched at {}": 
"Tensorboardインターフェースが{}で起動されました", + "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。", + "Training Configuration": "トレーニング設定", + "Training Error": "トレーニングエラー", + "Training stopped": "トレーニングが停止しました", + "Type name of the speaker": "話者の名前を入力", + "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください", + "Use LoRA": "LoRAを使用", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります", + "Use filelist": "ファイルリストを使用", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください", + "VITS Configuration": "VITS の構成", + "VQGAN Configuration": "VQGAN の構成", + "Validation Batch Size": "検証バッチサイズ", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。", + "WebUI Host": "WebUIホスト", + "WebUI Port": "WebUIポート", + "Whisper Model": "Whisperモデル", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします", + "latest": "最新", + "new": "新規", + "Realtime Transform Text": "リアルタイム変換テキスト", + "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)", + "Text Normalization": "テキスト正規化" + +} diff --git a/fish_speech/i18n/locale/pt_BR.json b/fish_speech/i18n/locale/pt_BR.json new file mode 100644 index 0000000000000000000000000000000000000000..385f20272e19053ab9b6cf6463a84c8ece768c68 --- /dev/null +++ b/fish_speech/i18n/locale/pt_BR.json @@ -0,0 +1,133 @@ +{ + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).", + "Accumulate Gradient Batches": "Acumular Lotes de Gradiente", + "Add to Processing Area": "Adicionar à Área de Processamento", + "Added path successfully!": "Caminho adicionado com sucesso!", + "Advanced Config": "Configuração Avançada", + "Base LLAMA Model": "Modelo LLAMA Base", + "Batch Inference": "Inferência em Lote", + "Batch Size": "Tamanho do Lote", + "Changing with the Model Path": "Alterando com o Caminho do Modelo", + + "Compile Model": "Compilar Modelo", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial", + "Copy": "Copiar", + "Data Preprocessing": "Pré-processamento de Dados", + "Data Preprocessing Path": "Caminho de Pré-processamento de Dados", + "Data Source": "Fonte de Dados", + "Decoder Model Config": "Configuração do Modelo Decodificador", + "Decoder Model Path": 
"Caminho do Modelo Decodificador", + "Disabled": "Desativado", + "Enable Initial Prompt": "Habilitar Prompt Inicial", + "Enable Reference Audio": "Habilitar Áudio de Referência", + "English": "Inglês", + "Japanese": "Japonês", + "Chinese": "Chinês", + "Portuguese": "Português", + "Spanish": "Espanhol", + "Error Message": "Mensagem de Erro", + "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)", + "File Preprocessing": "Pré-processamento de Arquivos", + "Generate": "Gerar", + "Generated Audio": "Áudio Gerado", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)", + "Infer interface is closed": "A interface de inferência foi fechada", + "Inference Configuration": "Configuração de Inferência", + "Inference Server Configuration": "Configuração do Servidor de Inferência", + "Inference Server Error": "Erro do Servidor de Inferência", + "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}", + "Initial Learning Rate": "Taxa de Aprendizagem Inicial", + "Initial Prompt": "Prompt Inicial", + "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.", + "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição", + "Input Text": "Texto de Entrada", + "Invalid path: {}": "Caminho inválido: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU", + "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)", + "LLAMA Configuration": "Configuração do LLAMA", + "LLAMA Model Config": "Configuração do Modelo LLAMA", + "LLAMA Model Path": "Caminho do Modelo LLAMA", + "Labeling Device": "Dispositivo de Rotulagem", + "LoRA Model to be merged": "Modelo LoRA para mesclagem", + "Maximum Length per Sample": "Comprimento Máximo por Amostra", + "Maximum Training Steps": "Etapas Máximas de Treinamento", + "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite", + "Merge": "Mesclar", + "Merge LoRA": "Mesclar LoRA", + "Merge successfully": "Mesclado com sucesso", + "Model Output Path": "Caminho de Saída do Modelo", + "Model Quantization": "Quantização do Modelo", + "Model Size": "Tamanho do Modelo", + "Move": "Mover", + "Move files successfully": "Arquivos movidos com sucesso", + "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.", + "No selected options": "Nenhuma opção selecionada", + "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)", + "Number of Workers": "Número de Processos", + "Open Inference Server": "Abrir Servidor de Inferência", + "Open Labeler WebUI": "Abrir WebUI de Rotulagem", + "Open Tensorboard": "Abrir Tensorboard", + "Opened labeler in browser": "WebUI de rotulagem aberta no navegador", + "Optional Label Language": "Idioma do Rótulo (Opcional)", + "Optional online ver": "Versão online (opcional)", + "Output Path": "Caminho de Saída", + "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente", + 
"Post-quantification Precision": "Precisão Pós-quantização", + "Precision": "Precisão", + "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador", + "Put your text here.": "Insira seu texto aqui.", + "Quantify": "Quantizar", + "Quantify successfully": "Quantizado com sucesso", + "Realtime Transform Text": "Transformar Texto em Tempo Real", + "Reference Audio": "Áudio de Referência", + "Reference Text": "Texto de Referência", + "warning": "Aviso", + "Pre-processing begins...": "O pré-processamento começou!", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.", + "Remove Selected Data": "Remover Dados Selecionados", + "Removed path successfully!": "Caminho removido com sucesso!", + "Repetition Penalty": "Penalidade de Repetição", + "Save model every n steps": "Salvar modelo a cada n etapas", + "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA", + "Select source file processing method": "Escolha como processar o arquivo de origem", + "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)", + "Selected: {}": "Selecionado: {}", + "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta", + "Start Training": "Iniciar Treinamento", + "Streaming Audio": "Áudio em Streaming", + "Streaming Generate": "Geração em Streaming", + "Tensorboard Host": "Host do Tensorboard", + "Tensorboard Log Path": "Caminho de Log do Tensorboard", + "Tensorboard Port": "Porta do Tensorboard", + "Tensorboard interface is closed": "A interface do Tensorboard está fechada", + "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}", + "Text Normalization": "Normalização de Texto", + "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.", + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência", + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.", + "Training Configuration": "Configuração de Treinamento", + "Training Error": "Erro de Treinamento", + "Training stopped": "Treinamento interrompido!", + "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso", + "Use LoRA": "Usar LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade", + "Use filelist": "Usar lista de arquivos", + "VQGAN Configuration": "Configuração do VQGAN", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. 
Por favor, considere as leis e regulamentações locais antes de usá-lo.", + "WebUI Host": "Host da WebUI", + "WebUI Port": "Porta da WebUI", + "Whisper Model": "Modelo Whisper", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).", + "auto": "automático", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+", + "latest": "mais recente", + "new": "novo", + "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.", + "You don't need to train this model!": "Não é necessário treinar este modelo!", + "Yes": "Sim", + "No": "Não", + "version:": "versão:", + "author:": "autor:" +} diff --git a/fish_speech/i18n/locale/zh_CN.json b/fish_speech/i18n/locale/zh_CN.json new file mode 100644 index 0000000000000000000000000000000000000000..3dd1a5cd1ccf3860ca508238cc64a68ca4fc3276 --- /dev/null +++ b/fish_speech/i18n/locale/zh_CN.json @@ -0,0 +1,122 @@ +{ + "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed", + "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。", + "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.", + "Accumulate Gradient Batches": "梯度累积批次", + "Add to Processing Area": "加入处理区", + "Added path successfully!": "添加路径成功!", + "Advanced Config": "高级参数", + "Base LLAMA Model": "基础 LLAMA 模型", + "Batch Inference": "批量推理", + "Batch Size": "批次大小", + "Changing with the Model Path": "随模型路径变化", + "Chinese": "中文", + "Compile Model": "编译模型", + "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间", + "Copy": "复制", + "Data Preprocessing": "数据预处理", + "Data Preprocessing Path": "数据预处理路径", + "Data Source": "数据源", + "Decoder Model Config": "解码器模型配置", + "Decoder Model Path": "解码器模型路径", + "Disabled": "禁用", + "Enable Reference Audio": "启用参考音频", + "English": "英文", + "Error Message": "错误信息", + "File Preprocessing": "文件预处理", + "Generate": "生成", + "Generated Audio": "音频", + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式", + "Infer interface is closed": "推理界面已关闭", + "Inference Configuration": "推理配置", + "Inference Server Configuration": "推理服务器配置", + "Inference Server Error": "推理服务器错误", + "Inferring interface is launched at {}": "推理界面已在 {} 上启动", + "Initial Learning Rate": "初始学习率", + "Input Audio & Source Path for Transcription": "输入音频和转录源路径", + "Input Text": "输入文本", + "Invalid path: {}": "无效路径: {}", + "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU", + "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭", + "Japanese": "日文", + "LLAMA Configuration": "LLAMA 配置", + "LLAMA Model Config": "LLAMA 模型配置", + "LLAMA Model Path": "LLAMA 模型路径", + "Labeling Device": "标注加速设备", + "LoRA Model to be merged": "要合并的 LoRA 模型", + "Maximum Audio Duration": "最大音频时长", + "Maximum Length per Sample": 
"每个样本的最大长度", + "Maximum Training Steps": "最大训练步数", + "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制", + "Merge": "合并", + "Merge LoRA": "合并 LoRA", + "Merge successfully": "合并成功", + "Minimum Audio Duration": "最小音频时长", + "Model Output Path": "模型输出路径", + "Model Size": "模型规模", + "Move": "移动", + "Move files successfully": "移动文件成功", + "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.", + "No selected options": "没有选择的选项", + "Number of Workers": "数据加载进程数", + "Open Inference Server": "打开推理服务器", + "Open Labeler WebUI": "打开标注工具", + "Open Tensorboard": "打开 Tensorboard", + "Opened labeler in browser": "在浏览器中打开标注工具", + "Optional Label Language": "[可选] 标注语言", + "Optional online ver": "[可选] 使用在线版", + "Output Path": "输出路径", + "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径", + "Precision": "精度", + "Probability of applying Speaker Condition": "应用说话人条件的概率", + "Put your text here.": "在此处输入文本.", + "Reference Audio": "参考音频", + "Reference Text": "参考文本", + "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.", + "Remove Selected Data": "移除选中数据", + "Removed path successfully!": "移除路径成功!", + "Repetition Penalty": "重复惩罚", + "Save model every n steps": "每 n 步保存模型", + "Select LLAMA ckpt": "选择 LLAMA 检查点", + "Select VITS ckpt": "选择 VITS 检查点", + "Select VQGAN ckpt": "选择 VQGAN 检查点", + "Select source file processing method": "选择源文件处理方法", + "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型", + "Selected: {}": "已选择: {}", + "Speaker": "说话人", + "Speaker is identified by the folder name": "自动根据父目录名称识别说话人", + "Start Training": "开始训练", + "Streaming Audio": "流式音频", + "Streaming Generate": "流式合成", + "Tensorboard Host": "Tensorboard 监听地址", + "Tensorboard Log Path": "Tensorboard 日志路径", + "Tensorboard Port": "Tensorboard 端口", + "Tensorboard interface is closed": "Tensorboard 界面已关闭", + "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动", + "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.", + "The path of the input folder on the left or the filelist. 
Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.", + "Training Configuration": "训练配置", + "Training Error": "训练错误", + "Training stopped": "训练已停止", + "Type name of the speaker": "输入说话人的名称", + "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择", + "Use LoRA": "使用 LoRA", + "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量", + "Use filelist": "使用文件列表", + "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small", + "VITS Configuration": "VITS 配置", + "VQGAN Configuration": "VQGAN 配置", + "Validation Batch Size": "验证批次大小", + "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)", + "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.", + "WebUI Host": "WebUI 监听地址", + "WebUI Port": "WebUI 端口", + "Whisper Model": "Whisper 模型", + "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.", + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed", + "latest": "最近的检查点", + "new": "创建新的检查点", + "Realtime Transform Text": "实时规范化文本", + "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览", + "Text Normalization": "文本规范化" +} diff --git a/fish_speech/i18n/scan.py b/fish_speech/i18n/scan.py new file mode 100644 index 0000000000000000000000000000000000000000..d0194c0f1a31dc95309c64626d13f04751a44ba1 --- /dev/null +++ b/fish_speech/i18n/scan.py @@ -0,0 +1,122 @@ +import ast +import glob +import json +from collections import OrderedDict +from pathlib import Path + +from loguru import logger + +from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH + + +def extract_i18n_strings(node): + i18n_strings = [] + + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "i18n" + ): + for arg in node.args: + if isinstance(arg, ast.Str): + i18n_strings.append(arg.s) + + for child_node in ast.iter_child_nodes(node): + i18n_strings.extend(extract_i18n_strings(child_node)) + + return i18n_strings + + +# scan the directory for all .py files (recursively) +# for each file, parse the code into an AST +# for each AST, extract the i18n strings + +strings = [] +folders = ["fish_speech", "tools"] +# for filename in glob.iglob("**/*.py", recursive=True): +for folder in folders: + for f in Path(folder).rglob("*.py"): + code = f.read_text(encoding="utf-8") + if "i18n(" in code: + tree = ast.parse(code) + i18n_strings = extract_i18n_strings(tree) + logger.info(f"Found {len(i18n_strings)} i18n strings in {f}") + strings.extend(i18n_strings) + +code_keys = set(strings) +logger.info(f"Total unique: {len(code_keys)}") + + +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) +standard_keys = set(standard_data.keys()) + +# Define the standard file name +unused_keys = standard_keys - code_keys +logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}") +for unused_key in unused_keys: + logger.info(f"\t{unused_key}") + 
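Before the missing-key pass below, note how these locale tables are meant to be consumed: each file is a flat key-to-translation map, and a lookup simply loads the JSON for the active language and falls back to the English key when no translation exists (untranslated entries injected by this script are recognizable by their "#!" prefix). A minimal sketch of such a lookup (the real implementation lives in fish_speech/i18n/core.py, which is not part of this diff; the translate() helper, LOCALE_DIR, and the en_US default below are illustrative assumptions):

```python
import json
from pathlib import Path

# Hypothetical lookup helper; the real one lives in fish_speech/i18n/core.py,
# which is not shown in this diff. "en_US" as the default language is an assumption.
LOCALE_DIR = Path("fish_speech/i18n/locale")

def translate(key: str, lang: str = "ja_JP", default_lang: str = "en_US") -> str:
    locale_file = LOCALE_DIR / f"{lang}.json"
    if not locale_file.exists():
        locale_file = LOCALE_DIR / f"{default_lang}.json"
    table = json.loads(locale_file.read_text(encoding="utf-8"))
    # Missing keys fall back to the English key itself; keys that scan.py filled
    # in automatically are easy to spot because their values start with "#!".
    return table.get(key, key)

print(translate("Start Training"))  # -> "トレーニング開始" with the ja_JP table above
```

Because scan.py relies on the relative import from .core, it is presumably run as a module from the repository root, e.g. python -m fish_speech.i18n.scan.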
+missing_keys = code_keys - standard_keys +logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}") +for missing_key in missing_keys: + logger.info(f"\t{missing_key}") + +code_keys_dict = OrderedDict() +for s in strings: + code_keys_dict[s] = s + +# write back +with open(standard_file, "w", encoding="utf-8") as f: + json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + +logger.info(f"Updated {standard_file}") + + +# Define the standard file name +standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json" + +# Find all JSON files in the directory +dir_path = I18N_FILE_PATH +languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE] + +# Load the standard file +with open(standard_file, "r", encoding="utf-8") as f: + standard_data = json.load(f, object_pairs_hook=OrderedDict) + +# Loop through each language file +for lang_file in languages: + # Load the language file + with open(lang_file, "r", encoding="utf-8") as f: + lang_data = json.load(f, object_pairs_hook=OrderedDict) + + # Find the difference between the language file and the standard file + diff = set(standard_data.keys()) - set(lang_data.keys()) + + miss = set(lang_data.keys()) - set(standard_data.keys()) + + # Add any missing keys to the language file + for key in diff: + lang_data[key] = "#!" + key + logger.info(f"Added missing key: {key} to {lang_file}") + + # Del any extra keys to the language file + for key in miss: + del lang_data[key] + logger.info(f"Del extra key: {key} from {lang_file}") + + # Sort the keys of the language file to match the order of the standard file + lang_data = OrderedDict( + sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0])) + ) + + # Save the updated language file + with open(lang_file, "w", encoding="utf-8") as f: + json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True) + f.write("\n") + + logger.info(f"Updated {lang_file}") + +logger.info("Done") diff --git a/fish_speech/models/text2semantic/__init__.py b/fish_speech/models/text2semantic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2660e31d27b749e906716f846ad0303f28c5d3ae Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10287acff4182fbf2964bed0c6512b752fe087bc Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/lit_module.cpython-310.pyc differ diff --git a/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c46a1595d473a5422c0f1a526faf162604c66191 Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/llama.cpython-310.pyc differ diff --git a/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..277545bc846fa08418ba7e846e38c006877bf95d Binary files /dev/null and b/fish_speech/models/text2semantic/__pycache__/lora.cpython-310.pyc differ diff --git a/fish_speech/models/text2semantic/lit_module.py b/fish_speech/models/text2semantic/lit_module.py new file mode 100644 index 0000000000000000000000000000000000000000..df970400f8a073be4c4166a697245fabdf6b09b0 --- /dev/null +++ b/fish_speech/models/text2semantic/lit_module.py @@ -0,0 +1,202 @@ +from typing import Any, Optional + +import lightning as L +import torch +import torch.nn.functional as F +from lightning.pytorch.utilities.types import OptimizerLRScheduler + +import fish_speech.utils as utils +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.models.text2semantic.llama import NaiveTransformer + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +class TextToSemantic(L.LightningModule): + def __init__( + self, + model: NaiveTransformer, + optimizer: Any, + lr_scheduler: Any, + ): + super().__init__() + + self.model = model + self.optimizer_builder = optimizer + self.lr_scheduler_builder = lr_scheduler + + def forward(self, x): + return self.model(x) + + def on_save_checkpoint(self, checkpoint): + # Save only LoRA parameters + state_dict = checkpoint["state_dict"] + use_lora = any("lora" in name for name in state_dict.keys()) + if not use_lora: + return + + for name in list(state_dict.keys()): + if "lora" not in name: + state_dict.pop(name) + + def configure_optimizers(self) -> OptimizerLRScheduler: + # Get weight decay parameters + weight_decay_parameters, other_parameters = [], [] + for name, param in self.named_parameters(): + if ".bias" in name or "norm.weight" in name or ".embeddings." in name: + other_parameters.append(param) + else: + weight_decay_parameters.append(param) + + optimizer = self.optimizer_builder( + [ + {"params": weight_decay_parameters}, + {"params": other_parameters, "weight_decay": 0.0}, + ] + ) + + # Print the parameters and their weight decay + for i in optimizer.param_groups: + log.info( + f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters" + ) + + lr_scheduler = self.lr_scheduler_builder(optimizer) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": "step", + }, + } + + # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + def get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
+ """ + assert logits.shape[:-1] == labels.shape + + labels = labels.clone() + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == -100] = 0 + + per_token_logps = torch.gather( + logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1) + ).squeeze(-1) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def _step(self, batch, batch_idx, stage: str): + is_train = stage == "train" + + if is_train: + # Key part to make lora work + # Otherwise the parameters are merged, which lead to incorrect gradients + self.model.train() + + # Do positive and negative samples in the same batch to speed up training + labels = batch["labels"] + outputs = self.model( + inp=batch["inputs"], + key_padding_mask=batch["attention_masks"], + ) + token_logits = outputs.token_logits + codebook_logits = outputs.codebook_logits + + # Generate labels + base_loss = F.cross_entropy( + token_logits.view(-1, token_logits.size(-1)), + labels[:, 0].reshape(-1), + ignore_index=-100, + ) + + codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT + semantic_loss = F.cross_entropy( + codebook_logits.view(-1, codebook_logits.size(-1)), + codebook_labels.reshape(-1), + ignore_index=-100, + ) + + loss = base_loss + semantic_loss + + self.log( + f"{stage}/loss", + loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/base_loss", + base_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + self.log( + f"{stage}/semantic_loss", + semantic_loss, + on_step=is_train, + on_epoch=not is_train, + prog_bar=False, + logger=True, + sync_dist=not is_train, + ) + + # Top-5 accuracy + accuracy = self.get_accuracy(codebook_logits, codebook_labels) + self.log( + f"{stage}/top_5_accuracy", + accuracy, + on_step=is_train, + on_epoch=not is_train, + prog_bar=True, + logger=True, + sync_dist=not is_train, + ) + + return loss + + def get_accuracy(self, logits, labels): + mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID) + if mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + _, indices = logits.topk(5, dim=-1) + correct = indices.eq(labels.unsqueeze(-1)) + correct[~mask] = 0 + correct = correct.sum() + accuracy = correct / mask.sum() + + return accuracy + + def training_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "train") + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx, "val") diff --git a/fish_speech/models/text2semantic/llama.py b/fish_speech/models/text2semantic/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..0725dfb9b78b1154753641b69c959a2faadba48c --- /dev/null +++ b/fish_speech/models/text2semantic/llama.py @@ -0,0 +1,779 @@ +import json +import math +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange +from loguru import logger +from torch import Tensor +from torch.nn import functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.utils.checkpoint import checkpoint +from transformers import AutoTokenizer + +from fish_speech.conversation import SEMANTIC_TOKEN +from fish_speech.utils import RankedLogger + +from .lora import LoraConfig, setup_lora 
+ +log = RankedLogger(__name__, rank_zero_only=True) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class BaseModelArgs: + model_type: str = "base" + + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + max_seq_len: int = 2048 + dropout: float = 0.0 + tie_word_embeddings: bool = True + attention_qkv_bias: bool = False + + # Codebook configs + codebook_size: int = 160 + num_codebooks: int = 4 + + # Gradient checkpointing + use_gradient_checkpointing: bool = True + + # Initialize the model + initializer_range: float = 0.02 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @staticmethod + def from_pretrained(path: str): + path = Path(path) + + if path.is_dir(): + path = path / "config.json" + + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + match data["model_type"]: + case "naive": + cls = NaiveModelArgs + case "dual_ar": + cls = DualARModelArgs + case _: + raise ValueError(f"Unknown model type: {data['model_type']}") + + return cls(**data) + + def save(self, path: str): + with open(path, "w") as f: + json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) + + +@dataclass +class NaiveModelArgs(BaseModelArgs): + model_type: str = "naive" + + +@dataclass +class DualARModelArgs(BaseModelArgs): + model_type: str = "dual_ar" + n_fast_layer: int = 4 + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +@dataclass +class TransformerForwardResult: + token_logits: Tensor + codebook_logits: Tensor + + +@dataclass +class BaseTransformerForwardResult: + logits: Tensor + hidden_states: Tensor + + +class BaseTransformer(nn.Module): + def __init__( + self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True + ) -> None: + super().__init__() + self.config = config + self.tokenizer = tokenizer + + self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN) + + # Slow transformer + self.embeddings = nn.Embedding( + config.vocab_size, + config.dim, + ) + self.codebook_embeddings = nn.Embedding( + config.codebook_size * config.num_codebooks, + config.dim, + ) + self.layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + + if self.config.tie_word_embeddings is False: + self.output = nn.Linear( + config.dim, + config.vocab_size, + bias=False, + ) + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + config.max_seq_len, + config.dim // config.n_head, + config.rope_base, + ), + 
persistent=False, + ) + self.register_buffer( + "causal_mask", + torch.tril( + torch.ones( + config.max_seq_len, + config.max_seq_len, + dtype=torch.bool, + ) + ), + persistent=False, + ) + + # For kv cache + self.max_batch_size = -1 + self.max_seq_len = -1 + + if init_weights: + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: + return + + head_dim = self.config.dim // self.config.n_head + max_seq_len = find_multiple(max_seq_len, 8) + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, + max_seq_len, + self.config.n_local_heads, + head_dim, + dtype=dtype, + ) + + def embed(self, x: Tensor) -> Tensor: + vocab_embeds = [self.embeddings(x[:, 0])] + for i in range(self.config.num_codebooks): + emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size) + emb[x[:, 0] != self.semantic_token_id] = 0 + vocab_embeds.append(emb) + + x = torch.stack(vocab_embeds, dim=3) + x = x.sum(dim=3) + + return x + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> BaseTransformerForwardResult: + seq_len = inp.size(2) + + # Here we want to merge the embeddings of the codebooks + x = self.embed(inp) + + freqs_cis = self.freqs_cis[:seq_len] + + # Not that the causal mask here follows the definition of scaled_dot_product_attention + # That is, FALSE means masked out + # To maintain consistency, key_padding_mask use TRUE to mask out + mask = None + if key_padding_mask is not None: + mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) + mask = mask & key_padding_mask[:, None, None, :].logical_not() + + for layer in self.layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) + else: + x = layer(x, freqs_cis, mask) + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def forward_generate( + self, + x: Tensor, + input_pos: Optional[Tensor] = None, + return_all: bool = False, + ) -> BaseTransformerForwardResult: + # This is used for generation, optimized for torch compile + assert ( + self.max_seq_len != -1 and self.max_batch_size != -1 + ), "Please call setup_caches before forward_generate" + + x = self.embed(x) + + mask = self.causal_mask[ + None, None, input_pos, : self.max_seq_len + ] # (B, N, Q, K) + freqs_cis = self.freqs_cis[input_pos] + + for layer in self.layers: + x = layer(x, freqs_cis, mask, input_pos=input_pos) + + # If prefill, we only calculate the logits of last token + if x.size(1) > 1 and not return_all: + x = x[:, -1:] + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + 
module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @staticmethod + def from_pretrained( + path: str, + load_weights: bool = False, + max_length: int | None = None, + lora_config: LoraConfig | None = None, + rope_base: int | None = None, + ) -> "BaseTransformer": + config = BaseModelArgs.from_pretrained(str(path)) + if max_length is not None: + config.max_seq_len = max_length + log.info(f"Override max_seq_len to {max_length}") + + if rope_base is not None: + config.rope_base = rope_base + log.info(f"Override rope_base to {rope_base}") + + match config.model_type: + case "naive": + model_cls = NaiveTransformer + case "dual_ar": + model_cls = DualARTransformer + case _: + raise ValueError(f"Unknown model type: {config.model_type}") + + tokenizer = AutoTokenizer.from_pretrained(str(path)) + log.info(f"Loading model from {path}, config: {config}") + model = model_cls(config, tokenizer=tokenizer) + + if lora_config is not None: + setup_lora(model, lora_config) + log.info(f"LoRA setup: {lora_config}") + + if load_weights is False: + log.info("Randomly initialized model") + else: + + if "int8" in str(Path(path)): + logger.info("Using int8 weight-only quantization!") + from tools.llama.quantize import WeightOnlyInt8QuantHandler + + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(Path(path)): + logger.info("Using int4 quantization!") + path_comps = path.name.split("-") + assert path_comps[-2].startswith("g") + groupsize = int(path_comps[-2][1:]) + from tools.llama.quantize import WeightOnlyInt4QuantHandler + + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + weights = torch.load( + Path(path) / "model.pth", map_location="cpu", mmap=True + ) + + if "state_dict" in weights: + logger.warning( + "Using a TextToSemantic LightningModule checkpoint, " + "please make sure it is a full model, not a LoRA model." + ) + weights = weights["state_dict"] + + if next(iter(weights.keys())).startswith("model."): + logger.info( + f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys" + ) + new_weights = OrderedDict() + for k, v in weights.items(): + new_weights[k.replace("model.", "")] = v + weights = new_weights + + # Verify the name and shape of parameters since strict=False in load_state_dict. 
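+            # strict=False tolerates key-set mismatches (for example LoRA- or
+            # quantization-specific tensors), and assign=True re-uses the loaded tensors
+            # directly instead of copying them into the freshly initialized parameters;
+            # mismatches are only surfaced by the warnings below and the logged err value.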
+ for k, v in model.named_parameters(): + if k not in weights: + logger.warning(f"No weight for {k}") + elif v.shape != weights[k].shape: + logger.warning( + f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}" + ) + + err = model.load_state_dict(weights, strict=False, assign=True) + log.info(f"Loaded weights with error: {err}") + + return model + + def save_pretrained(self, path: str, drop_lora: bool = False): + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + self.config.save(path / "config.json") + state_dict = self.state_dict() + + if drop_lora: + for key in list(state_dict.keys()): + if "lora" not in key: + continue + + state_dict.pop(key) + log.info(f"Drop LoRA parameter: {key}") + + torch.save(state_dict, path / "model.pth") + self.tokenizer.save_pretrained(path) + + +class NaiveTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.codebook_output = nn.Linear( + config.dim, + config.codebook_size * config.num_codebooks, + bias=False, + ) + + self.apply(self._init_weights) + + def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult: + token_logits = result.logits + x = result.hidden_states + + # Codebook + codebook_logits = self.codebook_output(self.codebook_norm(x)) + codebook_logits = rearrange( + codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + result = super().forward( + inp=inp, + key_padding_mask=key_padding_mask, + ) + return self.decode(result) + + def forward_generate( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> TransformerForwardResult: + result = super().forward_generate(x, input_pos) + return self.decode(result) + + +class DualARTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None: + super().__init__(config, init_weights=False, tokenizer=tokenizer) + + # Fast transformer + self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim) + + # The equivalent bs is so large that sdpa doesn't work + self.fast_layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer) + ) + self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.fast_output = nn.Linear( + config.dim, + config.codebook_size, + bias=False, + ) + + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16 + ): + super().setup_caches(max_batch_size, max_seq_len, dtype) + + head_dim = self.config.dim // self.config.n_head + + # Fast transformer + # The max seq len here is the number of codebooks + for b in self.fast_layers: + b.attention.kv_cache = KVCache( + max_batch_size, + self.config.num_codebooks, + self.config.n_local_heads, + head_dim, + dtype=dtype, + ) + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> TransformerForwardResult: + parent_result = super().forward(inp, key_padding_mask) + token_logits = parent_result.logits + x = parent_result.hidden_states + + # Fast transformer + fast_seq_len = self.config.num_codebooks + fast_mask = self.causal_mask[ + None, None, 
:fast_seq_len, :fast_seq_len + ] # (B, N, Q, K) + fast_freqs_cis = self.freqs_cis[:fast_seq_len] + + # Drop the last token and rotate left + codebooks = inp[:, 1:-1, 1:] + codebooks = F.pad(codebooks, (0, 1), value=0) + codebook_embeddings = self.fast_embeddings(codebooks) + x = torch.cat([x[:, None], codebook_embeddings], dim=1) + b, s = x.size(0), x.size(2) + x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len + + # Remove padded part + codebooks = rearrange(codebooks, "b n s -> (b s) n") + codebook_mask = (codebooks == 0).all(dim=-1) + + if torch.all(codebook_mask): + # If all codebooks are padded, we keep first 8 to make sure the model runs + codebook_mask[:8] = False + + x_bs, x_len = x.size(0), x.size(1) + x = x[~codebook_mask] + + for layer in self.fast_layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True) + else: + x = layer(x, fast_freqs_cis, fast_mask) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) + codebook_logits = self.fast_output(fast_out) + + # Re-pad the codebook_logits + buffer = torch.zeros( + x_bs, + x_len, + codebook_logits.size(-1), + device=codebook_logits.device, + dtype=codebook_logits.dtype, + ) + buffer[~codebook_mask] = codebook_logits + codebook_logits = buffer + + assert codebook_logits.shape[1] == self.config.num_codebooks + codebook_logits = rearrange( + codebook_logits, + "(b s) n d -> b s n d", + b=b, + s=s, + n=self.config.num_codebooks, + ) + + return TransformerForwardResult( + token_logits=token_logits, + codebook_logits=codebook_logits, + ) + + def forward_generate_fast( + self, x: Tensor, input_pos: Optional[Tensor] = None + ) -> Tensor: + # Fast transformer + x = x.view(1, 1, -1) + + fast_mask = self.causal_mask[ + None, None, input_pos, : self.config.num_codebooks + ] # (B, N, Q, K) + fast_freqs_cis = self.freqs_cis[input_pos] + + for layer in self.fast_layers: + x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos) + + # unflatten the batch and num_codebooks + fast_out = self.fast_norm(x) # only take the last token + codebook_logits = self.fast_output(fast_out) + + return codebook_logits + + +class TransformerBlock(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: + super().__init__() + self.attention = Attention(config, use_sdpa=use_sdpa) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear( + config.dim, total_head_dim, bias=config.attention_qkv_bias + ) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.dropout = config.dropout + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.use_sdpa = use_sdpa + self._register_load_state_dict_pre_hook(self.load_hook) + + def 
load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + if self.use_sdpa: + if mask is None: + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + y = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + # No third party attn_mask here to use flash_attention + ) + else: + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + else: + y = self.eq_scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + return self.wo(y) + + def eq_scaled_dot_product_attention( + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + ) -> torch.Tensor: + # This is a standard scaled dot product attention + # It's low efficient, but it doesn't raise cuda error + + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + + return attn_weight @ value + + +class FeedForward(nn.Module): + def __init__(self, config: BaseModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, 
device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/fish_speech/models/text2semantic/lora.py b/fish_speech/models/text2semantic/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..647ca6fcccf038e17d2cf91a2874281dff3e0938 --- /dev/null +++ b/fish_speech/models/text2semantic/lora.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass + +import loralib as lora + + +@dataclass +class LoraConfig: + r: int + lora_alpha: float + lora_dropout: float = 0.0 + + +def setup_lora(model, lora_config): + # Replace the embedding layer with a LoRA layer + model.embeddings = lora.Embedding( + num_embeddings=model.embeddings.num_embeddings, + embedding_dim=model.embeddings.embedding_dim, + padding_idx=model.embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + model.codebook_embeddings = lora.Embedding( + num_embeddings=model.codebook_embeddings.num_embeddings, + embedding_dim=model.codebook_embeddings.embedding_dim, + padding_idx=model.codebook_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Replace output layer with a LoRA layer + linears = [(model, "output")] + + # Replace all linear layers with LoRA layers + for layer in model.layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + if hasattr(model, "fast_layers"): + model.fast_embeddings = lora.Embedding( + num_embeddings=model.fast_embeddings.num_embeddings, + embedding_dim=model.fast_embeddings.embedding_dim, + padding_idx=model.fast_embeddings.padding_idx, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + ) + + # Dual-AR model + linears.append((model, "fast_output")) + + for layer in model.fast_layers: + linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")]) + linears.extend( + [ + (layer.feed_forward, "w1"), + (layer.feed_forward, "w2"), + (layer.feed_forward, "w3"), + ] + ) + + for module, layer in linears: + updated_linear = lora.Linear( + in_features=getattr(module, layer).in_features, + out_features=getattr(module, layer).out_features, + bias=getattr(module, layer).bias, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + ) + setattr(module, layer, updated_linear) + + # Mark only the LoRA layers as trainable + lora.mark_only_lora_as_trainable(model, bias="none") + + +def get_merged_state_dict(model): + # This line will merge the state dict of the model and the LoRA parameters + model.eval() + + # Then we need to remove the LoRA parameters from the state dict + state_dict = model.state_dict() + for name in list(state_dict.keys()): + if "lora" in name: + state_dict.pop(name) + + return state_dict diff --git a/fish_speech/models/vqgan/__init__.py b/fish_speech/models/vqgan/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc b/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7370a6672b015f38616e92542abc71ddeeb7a87e Binary files /dev/null and b/fish_speech/models/vqgan/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..588bdb4e0f0f6fc5f9838713164c8ed4158b3303 Binary files /dev/null and b/fish_speech/models/vqgan/modules/__pycache__/firefly.cpython-310.pyc differ diff --git a/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc b/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22aab32a8842e848cceb650a0a9274c4402bfddb Binary files /dev/null and b/fish_speech/models/vqgan/modules/__pycache__/fsq.cpython-310.pyc differ diff --git a/fish_speech/models/vqgan/modules/firefly.py b/fish_speech/models/vqgan/modules/firefly.py new file mode 100644 index 0000000000000000000000000000000000000000..aa21839b544174d5d91378c5daf8fe1b376a154a --- /dev/null +++ b/fish_speech/models/vqgan/modules/firefly.py @@ -0,0 +1,596 @@ +import math +from functools import partial +from math import prod +from typing import Callable + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from torch.utils.checkpoint import checkpoint + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv1D") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + +def unpad1d(x: torch.Tensor, paddings: tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d( + x: torch.Tensor, + paddings: tuple[int, int], + mode: str = "zeros", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right + before the reflection happen. 
+ """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class FishConvNet(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1 + ): + super(FishConvNet, self).__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + ) + self.stride = stride + self.kernel_size = (kernel_size - 1) * dilation + 1 + self.dilation = dilation + + def forward(self, x): + pad = self.kernel_size - self.stride + extra_padding = get_extra_padding_for_conv1d( + x, self.kernel_size, self.stride, pad + ) + x = pad1d(x, (pad, extra_padding), mode="constant", value=0) + return self.conv(x).contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_weight_norm(self): + self.conv = remove_parametrizations(self.conv) + return self + + +class FishTransConvNet(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1): + super(FishTransConvNet, self).__init__() + self.conv = nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride=stride, dilation=dilation + ) + self.stride = stride + self.kernel_size = kernel_size + + def forward(self, x): + x = self.conv(x) + pad = self.kernel_size - self.stride + padding_right = math.ceil(pad) + padding_left = pad - padding_right + x = unpad1d(x, (padding_left, padding_right)) + return x.contiguous() + + def weight_norm(self, name="weight", dim=0): + self.conv = weight_norm(self.conv, name=name, dim=dim) + return self + + def remove_weight_norm(self): + self.conv = remove_parametrizations(self.conv) + return self + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[0] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[1] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[2] + ).weight_norm(), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[0] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[1] + ).weight_norm(), + FishConvNet( + channels, channels, kernel_size, stride=1, dilation=dilation[2] + ).weight_norm(), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.silu(x) + xt = c1(xt) + xt = F.silu(xt) + xt = c2(xt) + x = xt + x + return x + + def remove_parametrizations(self): + for conv in self.convs1: + remove_parametrizations(conv, tensor_name="weight") + for conv in self.convs2: + remove_parametrizations(conv, tensor_name="weight") + + +class ParallelBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_sizes: tuple[int] = (3, 7, 11), + dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), 
(1, 3, 5), (1, 3, 5)), + ): + super().__init__() + + assert len(kernel_sizes) == len(dilation_sizes) + + self.blocks = nn.ModuleList() + for k, d in zip(kernel_sizes, dilation_sizes): + self.blocks.append(ResBlock1(channels, k, d)) + + def forward(self, x): + return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0) + + def remove_parametrizations(self): + for block in self.blocks: + block.remove_parametrizations() + + +class HiFiGANGenerator(nn.Module): + def __init__( + self, + *, + hop_length: int = 512, + upsample_rates: tuple[int] = (8, 8, 2, 2, 2), + upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2), + resblock_kernel_sizes: tuple[int] = (3, 7, 11), + resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + num_mels: int = 128, + upsample_initial_channel: int = 512, + pre_conv_kernel_size: int = 7, + post_conv_kernel_size: int = 7, + post_activation: Callable = partial(nn.SiLU, inplace=True), + ): + super().__init__() + + assert ( + prod(upsample_rates) == hop_length + ), f"hop_length must be {prod(upsample_rates)}" + + self.conv_pre = FishConvNet( + num_mels, + upsample_initial_channel, + pre_conv_kernel_size, + stride=1, + ).weight_norm() + + self.num_upsamples = len(upsample_rates) + self.num_kernels = len(resblock_kernel_sizes) + + self.noise_convs = nn.ModuleList() + self.ups = nn.ModuleList() + + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + FishTransConvNet( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + stride=u, + ).weight_norm() + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.resblocks.append( + ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes) + ) + + self.activation_post = post_activation() + self.conv_post = FishConvNet( + ch, 1, post_conv_kernel_size, stride=1 + ).weight_norm() + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = F.silu(x, inplace=True) + x = self.ups[i](x) + + if self.training and self.checkpointing: + x = checkpoint( + self.resblocks[i], + x, + use_reentrant=False, + ) + else: + x = self.resblocks[i](x) + + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_parametrizations(self): + for up in self.ups: + remove_parametrizations(up, tensor_name="weight") + for block in self.resblocks: + block.remove_parametrizations() + remove_parametrizations(self.conv_pre, tensor_name="weight") + remove_parametrizations(self.conv_post, tensor_name="weight") + + +# DropPath copied from timm library +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. 
+ + """ # noqa: E501 + + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ # noqa: E501 + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None] * x + self.bias[:, None] + return x + + +# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py +class ConvNeXtBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + kernel_size (int): Kernel size for depthwise conv. Default: 7. + dilation (int): Dilation for depthwise conv. Default: 1. 
+ """ # noqa: E501 + + def __init__( + self, + dim: int, + drop_path: float = 0.0, + layer_scale_init_value: float = 1e-6, + mlp_ratio: float = 4.0, + kernel_size: int = 7, + dilation: int = 1, + ): + super().__init__() + + self.dwconv = FishConvNet( + dim, + dim, + kernel_size=kernel_size, + # padding=int(dilation * (kernel_size - 1) / 2), + groups=dim, + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, int(mlp_ratio * dim) + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x, apply_residual: bool = True): + input = x + + x = self.dwconv(x) + x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.gamma is not None: + x = self.gamma * x + + x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) + x = self.drop_path(x) + + if apply_residual: + x = input + x + + return x + + +class ConvNeXtEncoder(nn.Module): + def __init__( + self, + input_channels: int = 3, + depths: list[int] = [3, 3, 9, 3], + dims: list[int] = [96, 192, 384, 768], + drop_path_rate: float = 0.0, + layer_scale_init_value: float = 1e-6, + kernel_size: int = 7, + ): + super().__init__() + assert len(depths) == len(dims) + + self.downsample_layers = nn.ModuleList() + stem = nn.Sequential( + FishConvNet( + input_channels, + dims[0], + kernel_size=7, + # padding=3, + # padding_mode="replicate", + # padding_mode="zeros", + ), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.downsample_layers.append(stem) + + for i in range(len(depths) - 1): + mid_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), + ) + self.downsample_layers.append(mid_layer) + + self.stages = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + cur = 0 + for i in range(len(depths)): + stage = nn.Sequential( + *[ + ConvNeXtBlock( + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + kernel_size=kernel_size, + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + for i in range(len(self.downsample_layers)): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + + return self.norm(x) + + +class FireflyArchitecture(nn.Module): + def __init__( + self, + backbone: nn.Module, + head: nn.Module, + quantizer: nn.Module, + spec_transform: nn.Module, + ): + super().__init__() + + self.backbone = backbone + self.head = head + self.quantizer = quantizer + self.spec_transform = spec_transform + self.downsample_factor = math.prod(self.quantizer.downsample_factor) + + def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor: + if self.spec_transform is not None: + x = self.spec_transform(x) + + x = self.backbone(x) + if mask is not None: + x = x * mask + + if 
self.quantizer is not None: + vq_result = self.quantizer(x) + x = vq_result.z + + if mask is not None: + x = x * mask + + x = self.head(x, template=template) + + if x.ndim == 2: + x = x[:, None, :] + + if self.vq is not None: + return x, vq_result + + return x + + def encode(self, audios, audio_lengths): + audios = audios.float() + + mels = self.spec_transform(audios) + mel_lengths = audio_lengths // self.spec_transform.hop_length + mel_masks = sequence_mask(mel_lengths, mels.shape[2]) + mel_masks_float_conv = mel_masks[:, None, :].float() + mels = mels * mel_masks_float_conv + + # Encode + encoded_features = self.backbone(mels) * mel_masks_float_conv + feature_lengths = mel_lengths // self.downsample_factor + + return self.quantizer.encode(encoded_features), feature_lengths + + def decode(self, indices, feature_lengths) -> torch.Tensor: + mel_masks = sequence_mask( + feature_lengths * self.downsample_factor, + indices.shape[2] * self.downsample_factor, + ) + mel_masks_float_conv = mel_masks[:, None, :].float() + audio_lengths = ( + feature_lengths * self.downsample_factor * self.spec_transform.hop_length + ) + + audio_masks = sequence_mask( + audio_lengths, + indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length, + ) + audio_masks_float_conv = audio_masks[:, None, :].float() + + z = self.quantizer.decode(indices) * mel_masks_float_conv + x = self.head(z) * audio_masks_float_conv + + return x, audio_lengths + + def remove_parametrizations(self): + if hasattr(self.backbone, "remove_parametrizations"): + self.backbone.remove_parametrizations() + + if hasattr(self.head, "remove_parametrizations"): + self.head.remove_parametrizations() + + @property + def device(self): + return next(self.parameters()).device diff --git a/fish_speech/models/vqgan/modules/fsq.py b/fish_speech/models/vqgan/modules/fsq.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea4853376b6e663404ff48d6c6b5f664dde4094 --- /dev/null +++ b/fish_speech/models/vqgan/modules/fsq.py @@ -0,0 +1,116 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from vector_quantize_pytorch import GroupedResidualFSQ + +from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet + + +@dataclass +class FSQResult: + z: torch.Tensor + codes: torch.Tensor + latents: torch.Tensor + + +class DownsampleFiniteScalarQuantize(nn.Module): + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + n_groups: int = 1, + levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10 + downsample_factor: tuple[int] = (2, 2), + downsample_dims: tuple[int] | None = None, + ): + super().__init__() + + if downsample_dims is None: + downsample_dims = [input_dim for _ in range(len(downsample_factor))] + + all_dims = (input_dim,) + tuple(downsample_dims) + + self.residual_fsq = GroupedResidualFSQ( + dim=all_dims[-1], + levels=levels, + num_quantizers=n_codebooks, + groups=n_groups, + ) + + self.downsample_factor = downsample_factor + self.downsample_dims = downsample_dims + + self.downsample = nn.Sequential( + *[ + nn.Sequential( + FishConvNet( + all_dims[idx], + all_dims[idx + 1], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx + 1]), + ) + for idx, factor in enumerate(downsample_factor) + ] + ) + + self.upsample = nn.Sequential( + *[ + nn.Sequential( + FishTransConvNet( + all_dims[idx + 1], + all_dims[idx], + kernel_size=factor, + stride=factor, + ), + ConvNeXtBlock(dim=all_dims[idx]), + ) + 
for idx, factor in reversed(list(enumerate(downsample_factor))) + ] + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, z) -> FSQResult: + original_shape = z.shape + z = self.downsample(z) + quantized, indices = self.residual_fsq(z.mT) + result = FSQResult( + z=quantized.mT, + codes=indices.mT, + latents=z, + ) + result.z = self.upsample(result.z) + + # Pad or crop z to match original shape + diff = original_shape[-1] - result.z.shape[-1] + left = diff // 2 + right = diff - left + + if diff > 0: + result.z = F.pad(result.z, (left, right)) + elif diff < 0: + result.z = result.z[..., left:-right] + + return result + + def encode(self, z): + z = self.downsample(z) + _, indices = self.residual_fsq(z.mT) + indices = rearrange(indices, "g b l r -> b (g r) l") + return indices + + def decode(self, indices: torch.Tensor): + indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups) + z_q = self.residual_fsq.get_output_from_indices(indices) + z_q = self.upsample(z_q.mT) + return z_q diff --git a/fish_speech/models/vqgan/utils.py b/fish_speech/models/vqgan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b90c131d214006875476a161cdfd2dffa8949dac --- /dev/null +++ b/fish_speech/models/vqgan/utils.py @@ -0,0 +1,94 @@ +import matplotlib +import torch +from matplotlib import pyplot as plt + +matplotlib.use("Agg") + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def plot_mel(data, titles=None): + fig, axes = plt.subplots(len(data), 1, squeeze=False) + + if titles is None: + titles = [None for i in range(len(data))] + + plt.tight_layout() + + for i in range(len(data)): + mel = data[i] + + if isinstance(mel, torch.Tensor): + mel = mel.float().detach().cpu().numpy() + + axes[i][0].imshow(mel, origin="lower") + axes[i][0].set_aspect(2.5, adjustable="box") + axes[i][0].set_ylim(0, mel.shape[0]) + axes[i][0].set_title(titles[i], fontsize="medium") + axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) + axes[i][0].set_anchor("W") + + return fig + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) + ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(in_act, n_channels): + n_channels_int = n_channels[0] + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, 
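# Editorial sketch (not part of the patch): DownsampleFiniteScalarQuantize shrinks the frame axis
# by prod(downsample_factor), quantizes with the grouped residual FSQ, then upsamples and
# pads/crops back to the input length. The small dims below are illustrative, not the training config.
import torch
from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize

quantizer = DownsampleFiniteScalarQuantize(
    input_dim=32, n_codebooks=2, n_groups=1, levels=(8, 5, 5, 5), downsample_factor=(2, 2)
)
features = torch.randn(1, 32, 100)                               # (batch, channels, frames)
result = quantizer(features)
assert result.z.shape == features.shape                          # restored to 100 frames
assert result.codes.shape[-1] == features.shape[-1] // 4         # 2 * 2 temporal downsampling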
n_channels_int:, :]) + acts = t_act * s_act + + return acts + + +def avg_with_mask(x, mask): + assert mask.dtype == torch.float, "Mask should be float" + + if mask.ndim == 2: + mask = mask.unsqueeze(1) + + if mask.shape[1] == 1: + mask = mask.expand_as(x) + + return (x * mask).sum() / mask.sum() diff --git a/fish_speech/scheduler.py b/fish_speech/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..43bed6a2210723a7d5e1ea0a48ba61140047ca29 --- /dev/null +++ b/fish_speech/scheduler.py @@ -0,0 +1,40 @@ +import math + + +def get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int | float, + num_training_steps: int, + num_cycles: float = 0.5, + final_lr_ratio: float = 0.0, +): + if 0 < num_warmup_steps < 1: # float mode + num_warmup_steps = int(num_warmup_steps * num_training_steps) + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + + return max( + final_lr_ratio, + 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) + + +def get_constant_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int | float, + num_training_steps: int | None = None, +): + if 0 < num_warmup_steps < 1: # float mode + num_warmup_steps = int(num_warmup_steps * num_training_steps) + + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + return 1.0 diff --git a/fish_speech/text/__init__.py b/fish_speech/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d740bd8eed447d162e55b165965dec17130377ce --- /dev/null +++ b/fish_speech/text/__init__.py @@ -0,0 +1,4 @@ +from .clean import clean_text +from .spliter import split_text + +__all__ = ["clean_text", "split_text"] diff --git a/fish_speech/text/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbda0e48251bdcc53c332c821e4ea9519047d490 Binary files /dev/null and b/fish_speech/text/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/text/__pycache__/clean.cpython-310.pyc b/fish_speech/text/__pycache__/clean.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8c648bb945e8d4ff16146dff98a70a779dab7eb Binary files /dev/null and b/fish_speech/text/__pycache__/clean.cpython-310.pyc differ diff --git a/fish_speech/text/__pycache__/spliter.cpython-310.pyc b/fish_speech/text/__pycache__/spliter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94114179529badf1ecd0fa37c19d5fdc6223dcf9 Binary files /dev/null and b/fish_speech/text/__pycache__/spliter.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/.gitignore b/fish_speech/text/chn_text_norm/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..75ea58fa4a7bf34fc9ab35afee24684aa6ef4c89 --- /dev/null +++ b/fish_speech/text/chn_text_norm/.gitignore @@ -0,0 +1,114 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before 
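# Editorial sketch (not part of the patch): the lambdas in fish_speech/scheduler.py plug into
# torch.optim.lr_scheduler.LambdaLR via functools.partial. The warmup/step counts and base lr
# below are illustrative values only.
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = LambdaLR(
    optimizer,
    lr_lambda=partial(
        get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=1000,        # a float in (0, 1) is treated as a fraction of training steps
        num_training_steps=100_000,
        final_lr_ratio=0.1,           # never decay below 10% of the base lr
    ),
)
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())        # roughly 3e-4 * 5 / 1000 while still in warmup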
PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# JetBrains PyCharm +.idea + +# Customize +references +url.txt + +# Git +.git diff --git a/fish_speech/text/chn_text_norm/README.md b/fish_speech/text/chn_text_norm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8450a2c6c0f8e40f4509f5be196eb9f9d2b9afb6 --- /dev/null +++ b/fish_speech/text/chn_text_norm/README.md @@ -0,0 +1,36 @@ +# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works. + +# Chn Text Norm + +this is a repository for chinese text normalization (no longer maintained). + +## Quick Start ## + +### Git Clone Repo ### + +git clone this repo to the root directory of your project which need to use it. + + cd /path/to/proj + git clone https://github.com/Joee1995/chn-text-norm.git + +after that, your doc tree should be: +``` +proj # root of your project +|--- chn_text_norm # this chn-text-norm tool + |--- text.py + |--- ... +|--- text_normalize.py # your text normalization code +|--- ... +``` + +### How to Use ? 
### + + # text_normalize.py + from chn_text_norm.text import * + + raw_text = 'your raw text' + text = Text(raw_text=raw_text).normalize() + +### How to add quantums ### + +打开test.py,然后你就知道怎么做了。 diff --git a/fish_speech/text/chn_text_norm/__init__.py b/fish_speech/text/chn_text_norm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34ff30c1d86436d172d82a2afe4f2914407b2056 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf1f70fdddf166b1e08a881f44651789cde5665b Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_class.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ba0d65d52f907c05d8672ab5ffeba1ef69b0a58 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_constant.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..565e2baec31a08a1d40e473dbaf8cc068c4b56eb Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/basic_util.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ac369bcff904eeb22fd4c359b7ef4d0dff2856b Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/cardinal.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb55a304422a219e00687fc987b4cdcfd8283dcd Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/date.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57bcee45c211f05e127b6556c6c3e0dc05a43e9c Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/digit.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2982394aae51d7fef63f6ce4c13444662ccde1af Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/fraction.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc new file mode 100644 index 
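# Editorial sketch (not part of the patch): in this repository the tool above is vendored under
# fish_speech.text.chn_text_norm, so the README snippet becomes the import below. The sample
# string comes from the module's own tests; digits are read out as Chinese characters.
from fish_speech.text.chn_text_norm.text import Text

print(Text(raw_text="固话:0595-23865596或23880880。").normalize())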
0000000000000000000000000000000000000000..5cdaa0642dce2713d356ae35a4dea9955df67e9c Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/money.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1572f267a79a5231149fc5bba14cb4b4d4907895 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/percentage.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b088af763cadeb63f0ee4308c56032c19da3ed1f Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/telephone.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc b/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f84f49e3880c7410255e8ab1038eeb94d5375656 Binary files /dev/null and b/fish_speech/text/chn_text_norm/__pycache__/text.cpython-310.pyc differ diff --git a/fish_speech/text/chn_text_norm/basic_class.py b/fish_speech/text/chn_text_norm/basic_class.py new file mode 100644 index 0000000000000000000000000000000000000000..58d8f8eb7fc85d0861f106667d8f4e3e52b54761 --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_class.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +"""基本类 +中文字符类 +中文数字/数位类 +中文数字类 +中文数位类 +中文数字系统类 +中文数学符号类 +*中文其他符号类 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES + + +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. 
'陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return "10^{}".format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit( + power=index + 1, + simplified=value[0], + traditional=value[1], + big_s=value[1], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), + simplified=value[0], + traditional=value[1], + big_s=value[0], + big_t=value[1], + ) + else: + raise ValueError( + "Counting type should be in {0} ({1} provided).".format( + NUMBERING_TYPES, numbering_type + ) + ) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__( + self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None + ): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. 
+ positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v diff --git a/fish_speech/text/chn_text_norm/basic_constant.py b/fish_speech/text/chn_text_norm/basic_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..9a65991b9a9d349a0571c80508633951e52749ef --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_constant.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +"""基本常量 +中文数字/数位/符号字符常量 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +CHINESE_DIGIS = "零一二三四五六七八九" +BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" +BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" + +ZERO_ALT = "〇" +ONE_ALT = "幺" +TWO_ALTS = ["两", "兩"] + +POSITIVE = ["正", "正"] +NEGATIVE = ["负", "負"] +POINT = ["点", "點"] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +# 中文数字系统类型 +NUMBERING_TYPES = ["low", "mid", "high"] diff --git a/fish_speech/text/chn_text_norm/basic_util.py b/fish_speech/text/chn_text_norm/basic_util.py new file mode 100644 index 0000000000000000000000000000000000000000..dbf6130be87f285eed9998186508ea489d3bac9e --- /dev/null +++ b/fish_speech/text/chn_text_norm/basic_util.py @@ -0,0 +1,342 @@ +# -*- coding: utf-8 -*- +"""基本方法 +创建中文数字系统 方法 +中文字符串 <=> 数字串 方法 +数字串 <=> 中文字符串 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-02" + +from fish_speech.text.chn_text_norm.basic_class import * +from fish_speech.text.chn_text_norm.basic_constant import * + + +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. 
+ 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + larger_units = [ + CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units) + ] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL, + ) + smaller_units = [ + CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units) + ] + # digis + chinese_digis = zip( + CHINESE_DIGIS, + CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, + BIG_CHINESE_DIGIS_TRADITIONAL, + ) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) + point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [ + d.traditional, + d.simplified, + d.big_s, + d.big_t, + d.alt_s, + d.alt_t, + ]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, "" + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], [ + get_symbol(c, system) for c in dec_string + ] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance( + integer_symbols[-2], CNU + ): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None) + ) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if ( + isinstance(result[-i - 1], CNU) + and result[-i - 1].power < current_unit.power + ): + result[-i - 1] = CNU( + result[-i - 1].power + current_unit.power, + None, + None, + None, + None, + ) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = "".join([str(d.value) for d in dec_part]) + if dec_part: + return "{0}.{1}".format(int_str, dec_str) + else: + return int_str + + +def num2chn( + number_string, + numbering_type=NUMBERING_TYPES[1], + big=False, + traditional=False, + alt_zero=False, + alt_one=False, + alt_two=True, + use_zeros=True, + use_units=True, +): + + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip("0") + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next( + u for u in reversed(system.units) if u.power < len(striped_string) + ) + result_string = value_string[: -result_unit.power] + return ( + get_value(result_string) + + [result_unit] + + get_value(striped_string[-result_unit.power :]) + ) + + system = create_system(numbering_type) + + int_dec = number_string.split(".") + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string) + ) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND( + 2, + system.digits[2].alt_s, + system.digits[2].alt_t, + system.digits[2].big_s, + system.digits[2].big_t, + ) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = ( + result_symbols[i + 1] if i < len(result_symbols) - 1 else None + ) + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance( + previous_symbol, (CNU, type(None)) + ): + if next_symbol.power != 1 and ( + (previous_symbol is None) or (previous_symbol.power != 1) + ): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = "big_" + if traditional: + attr_name += "t" + else: + attr_name += "s" + else: + if traditional: + attr_name = "traditional" + else: + attr_name = "simplified" + + result = "".join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s + ) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s + ) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + 
result + + # ^10, 11, .., 19 + if ( + len(result) >= 2 + and result[1] + in [ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0], + ] + and result[0] + in [ + CHINESE_DIGIS[1], + BIG_CHINESE_DIGIS_SIMPLIFIED[1], + BIG_CHINESE_DIGIS_TRADITIONAL[1], + ] + ): + result = result[1:] + + return result + + +if __name__ == "__main__": + + # 测试程序 + all_chinese_number_string = ( + CHINESE_DIGIS + + BIG_CHINESE_DIGIS_SIMPLIFIED + + BIG_CHINESE_DIGIS_TRADITIONAL + + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED + + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL + + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED + + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL + + ZERO_ALT + + ONE_ALT + + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT) + ) + + print("num:", chn2num("一万零四百零三点八零五")) + print("num:", chn2num("一亿六点三")) + print("num:", chn2num("一亿零六点三")) + print("num:", chn2num("两千零一亿六点三")) + # print('num:', chn2num('一零零八六')) + print("txt:", num2chn("10260.03", alt_zero=True)) + print("txt:", num2chn("20037.090", numbering_type="low", traditional=True)) + print("txt:", num2chn("100860001.77", numbering_type="high", big=True)) + print( + "txt:", + num2chn( + "059523810880", + alt_one=True, + alt_two=False, + use_lzeros=True, + use_rzeros=True, + use_units=False, + ), + ) + + print(all_chinese_number_string) diff --git a/fish_speech/text/chn_text_norm/cardinal.py b/fish_speech/text/chn_text_norm/cardinal.py new file mode 100644 index 0000000000000000000000000000000000000000..ace9f5ad8e7f3be3a8e41b11dc0b9f80db799616 --- /dev/null +++ b/fish_speech/text/chn_text_norm/cardinal.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +"""CARDINAL类 (包含小数DECIMAL类) +纯数 <=> 中文字符串 方法 +中文字符串 <=> 纯数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +if __name__ == "__main__": + + # 测试程序 + print(Cardinal(cardinal="21357.230").cardinal2chntext()) diff --git a/fish_speech/text/chn_text_norm/date.py b/fish_speech/text/chn_text_norm/date.py new file mode 100644 index 0000000000000000000000000000000000000000..77acfdb9a91df0fe3c615a0784f61aad87fbe56e --- /dev/null +++ b/fish_speech/text/chn_text_norm/date.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +"""DATE类 +日期 <=> 中文字符串 方法 +中文字符串 <=> 日期 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-07" + +from fish_speech.text.chn_text_norm.cardinal import Cardinal +from fish_speech.text.chn_text_norm.digit import Digit + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, 
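# Editorial sketch (not part of the patch): chn2num composes digit/unit symbols exactly as the
# compute_value docstring describes ('两千万' = 2000 * 10000), and num2chn is its inverse for the
# default 'mid' numbering system (alt_two=True renders the leading 二 as 两).
from fish_speech.text.chn_text_norm.basic_util import chn2num, num2chn

assert chn2num("两千万") == "20000000"
assert num2chn("20000000") == "两千万"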
other = date.strip().split("年", maxsplit=1) + year = Digit(digit=year).digit2chntext() + "年" + except ValueError: + other = date + year = "" + if other: + try: + month, day = other.strip().split("月", maxsplit=1) + month = Cardinal(cardinal=month).cardinal2chntext() + "月" + except ValueError: + day = date + month = "" + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = "" + day = "" + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +if __name__ == "__main__": + + # 测试 + print(Date(date="09年3月16日").date2chntext()) diff --git a/fish_speech/text/chn_text_norm/digit.py b/fish_speech/text/chn_text_norm/digit.py new file mode 100644 index 0000000000000000000000000000000000000000..47c0cd4ad0c700635f84470bfdacfbdafb4a6185 --- /dev/null +++ b/fish_speech/text/chn_text_norm/digit.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +"""DIGIT类 +数字串 <=> 中文字符串 方法 +中文字符串 <=> 数字串 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +if __name__ == "__main__": + + # 测试程序 + print(Digit(digit="2016").digit2chntext()) diff --git a/fish_speech/text/chn_text_norm/fraction.py b/fish_speech/text/chn_text_norm/fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..b43b6a7feb634d346d59a2b4ab84b77ac88df103 --- /dev/null +++ b/fish_speech/text/chn_text_norm/fraction.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +"""FRACTION类 +分数 <=> 中文字符串 方法 +中文字符串 <=> 分数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split("分之") + return chn2num(numerator) + "/" + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split("/") + return num2chn(denominator) + "分之" + num2chn(numerator) + + +if __name__ == "__main__": + + # 测试程序 + print(Fraction(fraction="2135/7230").fraction2chntext()) + print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction()) diff --git a/fish_speech/text/chn_text_norm/money.py b/fish_speech/text/chn_text_norm/money.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c980d32134e1460e96e5bcbcc73d0d55974d2a --- /dev/null +++ b/fish_speech/text/chn_text_norm/money.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +"""MONEY类 +金钱 <=> 中文字符串 方法 +中文字符串 <=> 金钱 方法 +""" +import re + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-08" + +from fish_speech.text.chn_text_norm.cardinal import Cardinal + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext() + ) + self.chntext = money + return self.chntext + + +if __name__ == "__main__": + + # 测试 + 
print(Money(money="21.5万元").money2chntext()) + print(Money(money="230块5毛").money2chntext()) diff --git a/fish_speech/text/chn_text_norm/percentage.py b/fish_speech/text/chn_text_norm/percentage.py new file mode 100644 index 0000000000000000000000000000000000000000..46abbf545af62eb951d8f6fe40bcf684587f81b0 --- /dev/null +++ b/fish_speech/text/chn_text_norm/percentage.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +"""PERCENTAGE类 +百分数 <=> 中文字符串 方法 +中文字符串 <=> 百分数 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-06" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip("百分之")) + "%" + + def percentage2chntext(self): + return "百分之" + num2chn(self.percentage.strip().strip("%")) + + +if __name__ == "__main__": + + # 测试程序 + print(Percentage(chntext="百分之五十六点零三").chntext2percentage()) + print(Percentage(percentage="65.3%").percentage2chntext()) diff --git a/fish_speech/text/chn_text_norm/telephone.py b/fish_speech/text/chn_text_norm/telephone.py new file mode 100644 index 0000000000000000000000000000000000000000..e72b546db628a3b807dc6235b59b188cae3153ff --- /dev/null +++ b/fish_speech/text/chn_text_norm/telephone.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +"""TELEPHONE类 +电话号码 <=> 中文字符串 方法 +中文字符串 <=> 电话号码 方法 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +from fish_speech.text.chn_text_norm.basic_util import * + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + + if fixed: + sil_parts = self.telephone.split("-") + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sil_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + else: + sp_parts = self.telephone.strip("+").split() + self.raw_chntext = "".join( + [num2chn(part, alt_two=False, use_units=False) for part in sp_parts] + ) + self.chntext = self.raw_chntext.replace("", "") + return self.chntext + + +if __name__ == "__main__": + + # 测试程序 + print(TelePhone(telephone="0595-23980880").telephone2chntext()) + # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone()) diff --git a/fish_speech/text/chn_text_norm/text.py b/fish_speech/text/chn_text_norm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..54086fd933c01e14c3c55cee9adb52eefb58fd31 --- /dev/null +++ b/fish_speech/text/chn_text_norm/text.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +""" +TEXT类 +""" + +__author__ = "Zhiyang Zhou " +__data__ = "2019-05-03" + +import re + +from fish_speech.text.chn_text_norm.cardinal import Cardinal +from fish_speech.text.chn_text_norm.date import Date +from fish_speech.text.chn_text_norm.digit import Digit +from fish_speech.text.chn_text_norm.fraction import Fraction +from fish_speech.text.chn_text_norm.money import Money +from fish_speech.text.chn_text_norm.percentage import Percentage +from fish_speech.text.chn_text_norm.telephone import TelePhone + +CURRENCY_NAMES = ( + "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" + 
"里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" +) +CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" +COM_QUANTIFIERS = ( + "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" + "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" + "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" + "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" + "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)" +) + + +class Text: + """ + Text类 + """ + + def __init__(self, raw_text, norm_text=None): + self.raw_text = "^" + raw_text + "$" + self.norm_text = norm_text + + def _particular(self): + text = self.norm_text + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) + self.norm_text = text + return self.norm_text + + def normalize(self): + text = self.raw_text + + # 规范化日期 + pattern = re.compile( + r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile( + r"\D+((\d+(\.\d+)?)[多余几]?" + + CURRENCY_UNITS + + "(\d" + + CURRENCY_UNITS + + "?)?)" + ) + matchers = pattern.findall(text) + if matchers: + # print('money') + for matcher in matchers: + text = text.replace( + matcher[0], Money(money=matcher[0]).money2chntext(), 1 + ) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + # print('telephone') + for matcher in matchers: + text = text.replace( + matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1 + ) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace( + matcher[0], + TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), + 1, + ) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + # print('fraction') + for matcher in matchers: + text = text.replace( + matcher, Fraction(fraction=matcher).fraction2chntext(), 1 + ) + + # 规范化百分数 + text = text.replace("%", "%") + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + # print('percentage') + for matcher in matchers: + text = text.replace( + matcher[0], + Percentage(percentage=matcher[0]).percentage2chntext(), + 1, + ) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" 
+ COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + # print('cardinal+quantifier') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + # print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + # print('cardinal') + for matcher in matchers: + text = text.replace( + matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1 + ) + + self.norm_text = text + self._particular() + + return self.norm_text.lstrip("^").rstrip("$") + + +if __name__ == "__main__": + + # 测试程序 + print(Text(raw_text="固话:0595-23865596或23880880。").normalize()) + print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize()) + print(Text(raw_text="分数:32477/76391。").normalize()) + print(Text(raw_text="百分数:80.03%。").normalize()) + print(Text(raw_text="编号:31520181154418。").normalize()) + print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize()) + print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize()) + print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize()) + print(Text(raw_text="特殊:O2O或B2C。").normalize()) diff --git a/fish_speech/text/clean.py b/fish_speech/text/clean.py new file mode 100644 index 0000000000000000000000000000000000000000..c228dfcd13324e8b1abe4ead5f01f4bd8ed0c33a --- /dev/null +++ b/fish_speech/text/clean.py @@ -0,0 +1,31 @@ +import re + +SYMBOLS_MAPPING = { + "“": "'", + "”": "'", + "‘": "'", + "’": "'", + "【": "", + "】": "", + "[": "", + "]": "", + "(": "", + ")": "", + "(": "", + ")": "", + "・": "·", +} + +REPLACE_SYMBOL_REGEX = re.compile( + "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys()) +) + + +def clean_text(text): + # Clean the text + text = text.strip() + + # Replace all chinese symbols with their english counterparts + text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text) + + return text diff --git a/fish_speech/text/spliter.py b/fish_speech/text/spliter.py new file mode 100644 index 0000000000000000000000000000000000000000..d4bb995487c4f53818c6b2a16cf0a886b4e02e84 --- /dev/null +++ b/fish_speech/text/spliter.py @@ -0,0 +1,130 @@ +import re +import string + +from fish_speech.text.clean import clean_text + + +def utf_8_len(text): + return len(text.encode("utf-8")) + + +def break_text(texts, length, splits: set): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if char in splits: + yield curr + curr = "" + + if curr: + yield curr + + +def break_text_by_length(texts, length): + for text in texts: + if utf_8_len(text) <= length: + yield text + continue + + curr = "" + for char in text: + curr += char + + if utf_8_len(curr) >= length: + yield curr + curr = "" + + if curr: + yield curr + + +def add_cleaned(curr, segments): + curr = curr.strip() + if curr and not all(c.isspace() or c in string.punctuation for c in curr): + segments.append(curr) + + +def protect_float(text): + # Turns 3.14 into <3_f_14> to prevent splitting + return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text) + + +def unprotect_float(text): + # Turns <3_f_14> into 3.14 + return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text) + + +def split_text(text, length): + text = clean_text(text) + + # Break the text into pieces with following rules: + # 1. 
Split the text at ".", "!", "?" if text is NOT a float + # 2. If the text is longer than length, split at "," + # 3. If the text is still longer than length, split at " " + # 4. If the text is still longer than length, split at any character to length + + texts = [text] + texts = map(protect_float, texts) + texts = break_text(texts, length, {".", "!", "?", "。", "!", "?"}) + texts = map(unprotect_float, texts) + texts = break_text(texts, length, {",", ","}) + texts = break_text(texts, length, {" "}) + texts = list(break_text_by_length(texts, length)) + + # Then, merge the texts into segments with length <= length + segments = [] + curr = "" + + for text in texts: + if utf_8_len(curr) + utf_8_len(text) <= length: + curr += text + else: + add_cleaned(curr, segments) + curr = text + + if curr: + add_cleaned(curr, segments) + + return segments + + +if __name__ == "__main__": + # Test the split_text function + + text = "This is a test sentence. This is another test sentence. And a third one." + + assert split_text(text, 50) == [ + "This is a test sentence.", + "This is another test sentence. And a third one.", + ] + assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"] + assert split_text(" ", 10) == [] + assert split_text("a", 10) == ["a"] + + text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines." + assert split_text(text, 50) == [ + "This is a test sentence with only commas,", + "and no dots, and no exclamation marks,", + "and no question marks, and no newlines.", + ] + + text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence." + # First half split at " ", second half split at "," + assert split_text(text, 50) == [ + "This is a test sentence This is a test sentence", + "This is a test sentence. This is a test sentence,", + "This is a test sentence, This is a test sentence.", + ] + + text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。" + assert split_text(text, 50) == [ + "这是一段很长的中文文本,", + "而且没有句号,也没有感叹号,", + "也没有问号,也没有换行符.", + ] diff --git a/fish_speech/train.py b/fish_speech/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e693f3adc4dda787bdd587aec29f53355f2b1653 --- /dev/null +++ b/fish_speech/train.py @@ -0,0 +1,141 @@ +import os + +os.environ["USE_LIBUV"] = "0" +import sys +from typing import Optional + +import hydra +import lightning as L +import pyrootutils +import torch +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from lightning.pytorch.strategies import DDPStrategy +from omegaconf import DictConfig, OmegaConf + +os.environ.pop("SLURM_NTASKS", None) +os.environ.pop("SLURM_JOB_NAME", None) +os.environ.pop("SLURM_NTASKS_PER_NODE", None) + +# register eval resolver and root +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +# Allow TF32 on Ampere GPUs +torch.set_float32_matmul_precision("high") +torch.backends.cudnn.allow_tf32 = True + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) + +import fish_speech.utils as utils + +log = utils.RankedLogger(__name__, rank_zero_only=True) + + +@utils.task_wrapper +def train(cfg: DictConfig) -> tuple[dict, dict]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. 
+ This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + Args: + cfg (DictConfig): Configuration composed by Hydra. + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ # noqa: E501 + + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=False) + + if cfg.get("deterministic"): + torch.use_deterministic_algorithms(True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, + callbacks=callbacks, + logger=logger, + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + + ckpt_path = cfg.get("ckpt_path") + auto_resume = False + + resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir) + if resume_ckpt_path is not None: + ckpt_path = resume_ckpt_path + auto_resume = True + + if ckpt_path is not None: + log.info(f"Resuming from checkpoint: {ckpt_path}") + + # resume weights only is disabled for auto-resume + if cfg.get("resume_weights_only") and auto_resume is False: + log.info("Resuming weights only!") + ckpt = torch.load(ckpt_path, map_location=model.device) + if "state_dict" in ckpt: + ckpt = ckpt["state_dict"] + err = model.load_state_dict(ckpt, strict=False) + log.info(f"Error loading state dict: {err}") + ckpt_path = None + + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! 
Using current weights for testing...") + ckpt_path = cfg.get("ckpt_path") + + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main( + version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml" +) +def main(cfg: DictConfig) -> Optional[float]: + # train the model + train(cfg) + + +if __name__ == "__main__": + main() diff --git a/fish_speech/utils/__init__.py b/fish_speech/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05378519dbd18361c639e33413d011e7307c9adb --- /dev/null +++ b/fish_speech/utils/__init__.py @@ -0,0 +1,23 @@ +from .braceexpand import braceexpand +from .context import autocast_exclude_mps +from .file import get_latest_checkpoint +from .instantiators import instantiate_callbacks, instantiate_loggers +from .logger import RankedLogger +from .logging_utils import log_hyperparameters +from .rich_utils import enforce_tags, print_config_tree +from .utils import extras, get_metric_value, task_wrapper + +__all__ = [ + "enforce_tags", + "extras", + "get_metric_value", + "RankedLogger", + "instantiate_callbacks", + "instantiate_loggers", + "log_hyperparameters", + "print_config_tree", + "task_wrapper", + "braceexpand", + "get_latest_checkpoint", + "autocast_exclude_mps", +] diff --git a/fish_speech/utils/__pycache__/__init__.cpython-310.pyc b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1275a8478b5b6c8ca96cd20f18be4f300e5fba8d Binary files /dev/null and b/fish_speech/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..611e658e7387832fb9d481e775466a60689e364c Binary files /dev/null and b/fish_speech/utils/__pycache__/braceexpand.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/context.cpython-310.pyc b/fish_speech/utils/__pycache__/context.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0701855f15ea618e6fca6bba156a480a26e06705 Binary files /dev/null and b/fish_speech/utils/__pycache__/context.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/file.cpython-310.pyc b/fish_speech/utils/__pycache__/file.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d52787c4c9346fa3ac90012057d87598170b1619 Binary files /dev/null and b/fish_speech/utils/__pycache__/file.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78c1b17fb8f7e05a50ed4056b404d7e60c2f104f Binary files /dev/null and b/fish_speech/utils/__pycache__/instantiators.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/logger.cpython-310.pyc b/fish_speech/utils/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32cfb48f1bac889f58a4059ebc3033b2ec328077 Binary files /dev/null and b/fish_speech/utils/__pycache__/logger.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc 
b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e24723cdd60dd27d036e1e3a72def349e22f5d8 Binary files /dev/null and b/fish_speech/utils/__pycache__/logging_utils.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8a99c143231037868f6d6593d101636c0955844 Binary files /dev/null and b/fish_speech/utils/__pycache__/rich_utils.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d83d0c6d63e8a397e659056c7f4ecdc3299f9135 Binary files /dev/null and b/fish_speech/utils/__pycache__/spectrogram.cpython-310.pyc differ diff --git a/fish_speech/utils/__pycache__/utils.cpython-310.pyc b/fish_speech/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbad6b1a0fbbd0e58817cd597ae6b9ed26f7e53a Binary files /dev/null and b/fish_speech/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/fish_speech/utils/braceexpand.py b/fish_speech/utils/braceexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ac739f01f7e10e039c68c1157d6c761064f974 --- /dev/null +++ b/fish_speech/utils/braceexpand.py @@ -0,0 +1,217 @@ +""" +Bash-style brace expansion +Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py +License: MIT +""" + +import re +import string +from itertools import chain, product +from typing import Iterable, Iterator, Optional + +__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"] + + +class UnbalancedBracesError(ValueError): + pass + + +alphabet = string.ascii_uppercase + string.ascii_lowercase + +int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$") +char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$") +escape_re = re.compile(r"\\(.)") + + +def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]: + """braceexpand(pattern) -> iterator over generated strings + + Returns an iterator over the strings resulting from brace expansion + of pattern. This function implements Brace Expansion as described in + bash(1), with the following limitations: + + * A pattern containing unbalanced braces will raise an + UnbalancedBracesError exception. In bash, unbalanced braces will either + be partly expanded or ignored. + + * A mixed-case character range like '{Z..a}' or '{a..Z}' will not + include the characters '[]^_`' between 'Z' and 'a'. + + When escape is True (the default), characters in pattern can be + prefixed with a backslash to cause them not to be interpreted as + special characters for brace expansion (such as '{', '}', ','). + To pass through a a literal backslash, double it ('\\\\'). + + When escape is False, backslashes in pattern have no special + meaning and will be preserved in the output. 
+ + Examples: + + >>> from braceexpand import braceexpand + + # Integer range + >>> list(braceexpand('item{1..3}')) + ['item1', 'item2', 'item3'] + + # Character range + >>> list(braceexpand('{a..c}')) + ['a', 'b', 'c'] + + # Sequence + >>> list(braceexpand('index.html{,.backup}')) + ['index.html', 'index.html.backup'] + + # Nested patterns + >>> list(braceexpand('python{2.{5..7},3.{2,3}}')) + ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3'] + + # Prefixing an integer with zero causes all numbers to be padded to + # the same width. + >>> list(braceexpand('{07..10}')) + ['07', '08', '09', '10'] + + # An optional increment can be specified for ranges. + >>> list(braceexpand('{a..g..2}')) + ['a', 'c', 'e', 'g'] + + # Ranges can go in both directions. + >>> list(braceexpand('{4..1}')) + ['4', '3', '2', '1'] + + # Numbers can be negative + >>> list(braceexpand('{2..-1}')) + ['2', '1', '0', '-1'] + + # Unbalanced braces raise an exception. + >>> list(braceexpand('{1{2,3}')) + Traceback (most recent call last): + ... + UnbalancedBracesError: Unbalanced braces: '{1{2,3}' + + # By default, the backslash is the escape character. + >>> list(braceexpand(r'{1\\{2,3}')) + ['1{2', '3'] + + # Setting 'escape' to False disables backslash escaping. + >>> list(braceexpand(r'\\{1,2}', escape=False)) + ['\\\\1', '\\\\2'] + + """ + return ( + escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape) + ) + + +def parse_pattern(pattern: str, escape: bool) -> Iterator[str]: + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'pattern:', pattern + while pos < len(pattern): + if escape and pattern[pos] == "\\": + pos += 2 + continue + elif pattern[pos] == "{": + if bracketdepth == 0 and pos > start: + # print 'literal:', pattern[start:pos] + items.append([pattern[start:pos]]) + start = pos + bracketdepth += 1 + elif pattern[pos] == "}": + bracketdepth -= 1 + if bracketdepth == 0: + # print 'expression:', pattern[start+1:pos] + expr = pattern[start + 1 : pos] + item = parse_expression(expr, escape) + if item is None: # not a range or sequence + items.extend([["{"], parse_pattern(expr, escape), ["}"]]) + else: + items.append(item) + start = pos + 1 # skip the closing brace + pos += 1 + + if bracketdepth != 0: # unbalanced braces + raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern) + + if start < pos: + items.append([pattern[start:]]) + + return ("".join(item) for item in product(*items)) + + +def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]: + int_range_match = int_range_re.match(expr) + if int_range_match: + return make_int_range(*int_range_match.groups()) + + char_range_match = char_range_re.match(expr) + if char_range_match: + return make_char_range(*char_range_match.groups()) + + return parse_sequence(expr, escape) + + +def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]: + # sequence -> chain(*sequence_items) + start = 0 + pos = 0 + bracketdepth = 0 + items: list[Iterable[str]] = [] + + # print 'sequence:', seq + while pos < len(seq): + if escape and seq[pos] == "\\": + pos += 2 + continue + elif seq[pos] == "{": + bracketdepth += 1 + elif seq[pos] == "}": + bracketdepth -= 1 + elif seq[pos] == "," and bracketdepth == 0: + items.append(parse_pattern(seq[start:pos], escape)) + start = pos + 1 # skip the comma + pos += 1 + + if bracketdepth != 0: + raise UnbalancedBracesError + if not items: + return None + + # part after the last comma (may be the empty string) + 
items.append(parse_pattern(seq[start:], escape)) + return chain(*items) + + +def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]: + if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]): + padding = max(len(left), len(right)) + else: + padding = 0 + step = (int(incr) or 1) if incr else 1 + start = int(left) + end = int(right) + r = range(start, end + 1, step) if start < end else range(start, end - 1, -step) + fmt = "%0{}d".format(padding) + return (fmt % i for i in r) + + +def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str: + step = (int(incr) or 1) if incr else 1 + start = alphabet.index(left) + end = alphabet.index(right) + if start < end: + return alphabet[start : end + 1 : step] + else: + end = end or -len(alphabet) + return alphabet[start : end - 1 : -step] + + +if __name__ == "__main__": + import doctest + import sys + + failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL) + if failed: + sys.exit(1) diff --git a/fish_speech/utils/context.py b/fish_speech/utils/context.py new file mode 100644 index 0000000000000000000000000000000000000000..f04a99290ab32f7fe5b60656075a2d03af8468d6 --- /dev/null +++ b/fish_speech/utils/context.py @@ -0,0 +1,13 @@ +from contextlib import nullcontext + +import torch + + +def autocast_exclude_mps( + device_type: str, dtype: torch.dtype +) -> nullcontext | torch.autocast: + return ( + nullcontext() + if torch.backends.mps.is_available() + else torch.autocast(device_type, dtype) + ) diff --git a/fish_speech/utils/file.py b/fish_speech/utils/file.py new file mode 100644 index 0000000000000000000000000000000000000000..78c82640a963fa556657107729f7543d2e7c3510 --- /dev/null +++ b/fish_speech/utils/file.py @@ -0,0 +1,16 @@ +import os +from pathlib import Path + + +def get_latest_checkpoint(path: Path | str) -> Path | None: + # Find the latest checkpoint + ckpt_dir = Path(path) + + if ckpt_dir.exists() is False: + return None + + ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime) + if len(ckpts) == 0: + return None + + return ckpts[-1] diff --git a/fish_speech/utils/instantiators.py b/fish_speech/utils/instantiators.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ee463924f588a35477937fbe3c3364043bdf3e --- /dev/null +++ b/fish_speech/utils/instantiators.py @@ -0,0 +1,50 @@ +from typing import List + +import hydra +from omegaconf import DictConfig +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger + +from .logger import RankedLogger + +log = RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! 
Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/fish_speech/utils/logger.py b/fish_speech/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..94f94f738d1d87404354d086c30ef0ad9ab04cdc --- /dev/null +++ b/fish_speech/utils/logger.py @@ -0,0 +1,55 @@ +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = True, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/fish_speech/utils/logging_utils.py b/fish_speech/utils/logging_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3b0a2519e12845f09e5fbe86dfccbf5b345429 --- /dev/null +++ b/fish_speech/utils/logging_utils.py @@ -0,0 +1,48 @@ +from lightning.pytorch.utilities import rank_zero_only + +from fish_speech.utils import logger as log + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! 
Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/fish_speech/utils/rich_utils.py b/fish_speech/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a465f54d610779766d51e3d1a020a3b1517fd1f --- /dev/null +++ b/fish_speech/utils/rich_utils.py @@ -0,0 +1,100 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from fish_speech.utils import logger as log + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ # noqa: E501 + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. " + + f"Skipping '{field}' config printing..." + ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" # noqa: E501 + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. 
Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/fish_speech/utils/spectrogram.py b/fish_speech/utils/spectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..01c3d7a2ab0f707ae92dbde0feb173927720c841 --- /dev/null +++ b/fish_speech/utils/spectrogram.py @@ -0,0 +1,122 @@ +import torch +import torchaudio.functional as F +from torch import Tensor, nn +from torchaudio.transforms import MelScale + + +class LinearSpectrogram(nn.Module): + def __init__( + self, + n_fft=2048, + win_length=2048, + hop_length=512, + center=False, + mode="pow2_sqrt", + ): + super().__init__() + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.mode = mode + + self.register_buffer("window", torch.hann_window(win_length), persistent=False) + + def forward(self, y: Tensor) -> Tensor: + if y.ndim == 3: + y = y.squeeze(1) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + (self.win_length - self.hop_length) // 2, + (self.win_length - self.hop_length + 1) // 2, + ), + mode="reflect", + ).squeeze(1) + + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + spec = torch.view_as_real(spec) + + if self.mode == "pow2_sqrt": + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + return spec + + +class LogMelSpectrogram(nn.Module): + def __init__( + self, + sample_rate=44100, + n_fft=2048, + win_length=2048, + hop_length=512, + n_mels=128, + center=False, + f_min=0.0, + f_max=None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max or float(sample_rate // 2) + + self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center) + + fb = F.melscale_fbanks( + n_freqs=self.n_fft // 2 + 1, + f_min=self.f_min, + f_max=self.f_max, + n_mels=self.n_mels, + sample_rate=self.sample_rate, + norm="slaney", + mel_scale="slaney", + ) + self.register_buffer( + "fb", + fb, + persistent=False, + ) + + def compress(self, x: Tensor) -> Tensor: + return torch.log(torch.clamp(x, min=1e-5)) + + def decompress(self, x: Tensor) -> Tensor: + return torch.exp(x) + + def apply_mel_scale(self, x: Tensor) -> Tensor: + return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) + + def forward( + self, x: Tensor, return_linear: bool = False, sample_rate: int = None + ) -> Tensor: + if sample_rate is not None and sample_rate != self.sample_rate: + x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate) + + linear = self.spectrogram(x) + x = self.apply_mel_scale(linear) + x = self.compress(x) + + if return_linear: + return x, self.compress(linear) + + return x diff --git a/fish_speech/utils/utils.py b/fish_speech/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c546bfa1eddd2ac6bf484cce1ec06da1d33fb121 --- /dev/null +++ b/fish_speech/utils/utils.py @@ -0,0 +1,114 @@ +import warnings +from importlib.util import find_spec +from typing 
import Callable + +from omegaconf import DictConfig + +from .logger import RankedLogger +from .rich_utils import enforce_tags, print_config_tree + +log = RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[dict, dict]: + + ... + + return metric_dict, object_dict + ``` + """ # noqa: E501 + + def wrap(cfg: DictConfig): + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or + # cause out-of-memory errors so when using hparam search + # plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.run_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") + + return metric_value diff --git a/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc b/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd0b8af3ea645c95065dbbe9b037384e54ad614 Binary files /dev/null and b/fish_speech/webui/__pycache__/launch_utils.cpython-310.pyc differ diff --git a/fish_speech/webui/css/style.css b/fish_speech/webui/css/style.css new file mode 100644 index 0000000000000000000000000000000000000000..3c7a22ecc31881a65a76369b0fd889330a0874c7 --- /dev/null +++ b/fish_speech/webui/css/style.css @@ -0,0 +1,161 @@ +:root { + --my-200: #80eeee; + --my-50: #ecfdf5; + --water-width: 300px; + --water-heigh: 300px; +} + + +/* general styled components */ +.tools { + align-items: center; + justify-content: center; +} + +.gradio-button { + max-width: 2.2em; + min-width: 2.2em !important; + height: 2.4em; + align-self: end; + line-height: 1em; + border-radius: 0.5em; + +} + +.gradio-button.secondary-down, .gradio-button.secondary-down:hover{ + box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; +} + +/* replace original footer with ours */ +a{ + font-weight: bold; + cursor: pointer; + color: #030C14 !important; +} + +footer { + display: none !important; +} + +#footer{ + text-align: center; +} + +#footer div{ + display: inline-block; +} + +#footer .versions{ + font-size: 85%; + opacity: 0.85; +} + +/*@keyframes moveBackground {*/ +/* 0% {*/ +/* background-position: 0 0;*/ +/* }*/ +/* 100% {*/ +/* background-position: -100px 100px;*/ +/* }*/ +/*}*/ +@keyframes moveJellyBackground { + 0% { + background-position: 0% 50%; + } + 50% { + background-position: 100% 50%; + } + 100% { + background-position: 0% 50%; + } +} + +.gradio-container { + position: absolute; + z-index: 10; +} + + +.quan { + position: absolute; + bottom: 0; + width: var(--water-width); + height: var(--water-heigh); + border-radius: 0; + /*border: 3px solid rgb(246, 247, 248);*/ + /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/ + z-index: 0; + +} + +.quan:last-child { + margin-right: 0; +} + +.shui { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + background-color: rgb(23, 106, 201); + border-radius: 0; + overflow: hidden; + z-index: 0; +} + +.shui::after { + + content: ''; + position: absolute; + top: 20%; + left: 50%; + width: 150%; + height: 150%; + border-radius: 40%; + background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%); + animation: shi 5s linear infinite; +} + +@keyframes shi { + 0% { + transform: translate(-50%, -65%) rotate(0deg); + } + 100% { + transform: translate(-50%, -65%) rotate(360deg); + } +} + +.shui::before { + content: ''; + position: absolute; + top: 20%; + left: 50%; + width: 150%; + height: 150%; + border-radius: 42%; + background-color: rgb(240, 228, 228, 0.2); + animation: xu 7s linear infinite; +} + +@keyframes xu { + 0% { + transform: translate(-50%, -60%) rotate(0deg); + } + 100% { + transform: translate(-50%, -60%) rotate(360deg); + } +} + +fieldset.data_src div.wrap label { + background: #f8bffee0 !important; +} + +.scrollable-component { + max-height: 100px; + overflow-y: auto; +} + +#file_accordion { + max-height: 220px !important; +} diff --git a/fish_speech/webui/html/footer.html b/fish_speech/webui/html/footer.html new file mode 100644 index 0000000000000000000000000000000000000000..ac1745aa6f41f86a17e3d95564c2bf7a8d7bb615 --- /dev/null +++ b/fish_speech/webui/html/footer.html @@ -0,0 
+1,11 @@ +
+ API +  •  + Github +  •  + Gradio +
+
+
+{versions} +
diff --git a/fish_speech/webui/js/animate.js b/fish_speech/webui/js/animate.js new file mode 100644 index 0000000000000000000000000000000000000000..0637a541a8e704632a42b89bdf1471b26e7bb868 --- /dev/null +++ b/fish_speech/webui/js/animate.js @@ -0,0 +1,69 @@ + +function createGradioAnimation() { + const params = new URLSearchParams(window.location.search); + if (!params.has('__theme')) { + params.set('__theme', 'light'); + window.location.search = params.toString(); + } + + var gradioApp = document.querySelector('gradio-app'); + if (gradioApp) { + + document.documentElement.style.setProperty('--my-200', '#80eeee'); + document.documentElement.style.setProperty('--my-50', '#ecfdf5'); + + // gradioApp.style.position = 'relative'; + // gradioApp.style.backgroundSize = '200% 200%'; + // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite'; + // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)'; + // gradioApp.style.display = 'flex'; + // gradioApp.style.justifyContent = 'flex-start'; + // gradioApp.style.flexWrap = 'nowrap'; + // gradioApp.style.overflowX = 'auto'; + + // for (let i = 0; i < 6; i++) { + // var quan = document.createElement('div'); + // quan.className = 'quan'; + // gradioApp.insertBefore(quan, gradioApp.firstChild); + // quan.id = 'quan' + i.toString(); + // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')'; + // var quanContainer = document.querySelector('.quan'); + // if (quanContainer) { + // var shui = document.createElement('div'); + // shui.className = 'shui'; + // quanContainer.insertBefore(shui, quanContainer.firstChild) + // } + // } + } + + var container = document.createElement('div'); + container.id = 'gradio-animation'; + container.style.fontSize = '2em'; + container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace'; + container.style.fontWeight = 'bold'; + container.style.textAlign = 'center'; + container.style.marginBottom = '20px'; + + var text = 'Welcome to Fish-Speech!'; + for (var i = 0; i < text.length; i++) { + (function(i){ + setTimeout(function(){ + var letter = document.createElement('span'); + letter.style.opacity = '0'; + letter.style.transition = 'opacity 0.5s'; + letter.innerText = text[i]; + + container.appendChild(letter); + + setTimeout(function() { + letter.style.opacity = '1'; + }, 50); + }, i * 200); + })(i); + } + + var gradioContainer = document.querySelector('.gradio-container'); + gradioContainer.insertBefore(container, gradioContainer.firstChild); + + return 'Animation created'; +} diff --git a/fish_speech/webui/launch_utils.py b/fish_speech/webui/launch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f57b595a20177800dbedd71faef573ee8398418 --- /dev/null +++ b/fish_speech/webui/launch_utils.py @@ -0,0 +1,120 @@ +import importlib.util +import os +import subprocess +import sys +from functools import lru_cache +from pathlib import Path +from typing import Iterable + +import gradio as gr +from gradio.themes.base import Base +from gradio.themes.utils import colors, fonts, sizes + +GIT = ( + (Path(os.environ.get("GIT_HOME", "")) / "git").resolve() + if sys.platform == "win32" + else "git" +) +GIT = str(GIT) + + +def is_module_installed(module_name: str) -> bool: + spec = importlib.util.find_spec(module_name) + return spec is not None + + +@lru_cache() +def commit_hash(): + try: + return subprocess.check_output( + [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8" + ).strip() + except Exception: + return "" 
+ + +def versions_html(): + import torch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = commit_hash() + hash = commit.strip("'").split(" ")[0] + + return f""" +version: {hash} + •  +python: {python_version} + •  +torch: {getattr(torch, '__long_version__',torch.__version__)} + •  +gradio: {gr.__version__} + •  +author: fishaudio +""" + + +def version_check(commit): + try: + import requests + + commits = requests.get( + "https://api.github.com/repos/fishaudio/fish-speech/branches/main" + ).json() + if commit != "" and commits["commit"]["sha"] != commit: + print("--------------------------------------------------------") + print("| You are not up to date with the most recent release. |") + print("| Consider running `git pull` to update. |") + print("--------------------------------------------------------") + elif commits["commit"]["sha"] == commit: + print("You are up to date with the most recent release.") + else: + print("Not a git clone, can't perform version check.") + except Exception as e: + print("version check failed", e) + + +class Seafoam(Base): + def __init__( + self, + *, + primary_hue: colors.Color | str = colors.emerald, + secondary_hue: colors.Color | str = colors.blue, + neutral_hue: colors.Color | str = colors.blue, + spacing_size: sizes.Size | str = sizes.spacing_md, + radius_size: sizes.Size | str = sizes.radius_md, + text_size: sizes.Size | str = sizes.text_lg, + font: fonts.Font | str | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("Quicksand"), + "ui-sans-serif", + "sans-serif", + ), + font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( + fonts.GoogleFont("IBM Plex Mono"), + "ui-monospace", + "monospace", + ), + ): + super().__init__( + primary_hue=primary_hue, + secondary_hue=secondary_hue, + neutral_hue=neutral_hue, + spacing_size=spacing_size, + radius_size=radius_size, + text_size=text_size, + font=font, + font_mono=font_mono, + ) + super().set( + button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)", + button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)", + button_primary_text_color="white", + button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)", + slider_color="*secondary_300", + slider_color_dark="*secondary_600", + block_title_text_weight="600", + block_border_width="3px", + block_shadow="*shadow_drop_lg", + button_shadow="*shadow_drop_lg", + button_small_padding="0px", + button_large_padding="3px", + ) diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec3fcac25de3cc7d239c4903403d1a4cd81567b --- /dev/null +++ b/fish_speech/webui/manage.py @@ -0,0 +1,1239 @@ +from __future__ import annotations + +import os + +os.environ["USE_LIBUV"] = "0" +import datetime +import html +import json +import platform +import shutil +import signal +import subprocess +import sys +from pathlib import Path + +import gradio as gr +import psutil +import yaml +from loguru import logger +from tqdm import tqdm + +PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python") +sys.path.insert(0, "") +print(sys.path) +cur_work_dir = Path(os.getcwd()).resolve() +print("You are in ", str(cur_work_dir)) + +from fish_speech.i18n import i18n +from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html + +config_path = cur_work_dir / "fish_speech" / "configs" +vqgan_yml_path = config_path / "firefly_gan_vq.yaml" 
+llama_yml_path = config_path / "text2semantic_finetune.yaml" + +env = os.environ.copy() +env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0" + +seafoam = Seafoam() + + +def build_html_error_message(error): + return f""" +
+ {html.escape(error)} +
+ """ + + +def build_html_ok_message(msg): + return f""" +
+ {html.escape(msg)} +
+ """ + + +def build_html_href(link, desc, msg): + return f""" + + {html.escape(msg)} + {desc} + + """ + + +def load_data_in_raw(path): + with open(path, "r", encoding="utf-8") as file: + data = file.read() + return str(data) + + +def kill_proc_tree(pid, including_parent=True): + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + # Process already terminated + return + + children = parent.children(recursive=True) + for child in children: + try: + os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + if including_parent: + try: + os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + + +system = platform.system() +p_label = None +p_infer = None +p_tensorboard = None + + +def kill_process(pid): + if system == "Windows": + cmd = "taskkill /t /f /pid %s" % pid + # os.system(cmd) + subprocess.run(cmd) + else: + kill_proc_tree(pid) + + +def change_label(if_label): + global p_label + if if_label == True and p_label is None: + url = "http://localhost:3000" + remote_url = "https://text-labeler.pages.dev/" + try: + p_label = subprocess.Popen( + [ + ( + "asr-label-linux-x64" + if sys.platform == "linux" + else "asr-label-win-x64.exe" + ) + ] + ) + except FileNotFoundError: + logger.warning("asr-label execution not found!") + + yield build_html_href( + link=remote_url, + desc=i18n("Optional online ver"), + msg=i18n("Opened labeler in browser"), + ) + + elif if_label == False and p_label is not None: + kill_process(p_label.pid) + p_label = None + yield build_html_ok_message("Nothing") + + +def clean_infer_cache(): + import tempfile + + temp_dir = Path(tempfile.gettempdir()) + gradio_dir = str(temp_dir / "gradio") + try: + shutil.rmtree(gradio_dir) + logger.info(f"Deleted cached audios: {gradio_dir}") + except PermissionError: + logger.info(f"Permission denied: Unable to delete {gradio_dir}") + except FileNotFoundError: + logger.info(f"{gradio_dir} was not found") + except Exception as e: + logger.info(f"An error occurred: {e}") + + +def change_infer( + if_infer, + host, + port, + infer_decoder_model, + infer_decoder_config, + infer_llama_model, + infer_compile, +): + global p_infer + if if_infer == True and p_infer == None: + env = os.environ.copy() + + env["GRADIO_SERVER_NAME"] = host + env["GRADIO_SERVER_PORT"] = port + # 启动第二个进程 + url = f"http://{host}:{port}" + yield build_html_ok_message( + i18n("Inferring interface is launched at {}").format(url) + ) + + clean_infer_cache() + + p_infer = subprocess.Popen( + [ + PYTHON, + "tools/webui.py", + "--decoder-checkpoint-path", + infer_decoder_model, + "--decoder-config-name", + infer_decoder_config, + "--llama-checkpoint-path", + infer_llama_model, + ] + + (["--compile"] if infer_compile == "Yes" else []), + env=env, + ) + + elif if_infer == False and p_infer is not None: + kill_process(p_infer.pid) + p_infer = None + yield build_html_error_message(i18n("Infer interface is closed")) + + +js = load_data_in_raw("fish_speech/webui/js/animate.js") +css = load_data_in_raw("fish_speech/webui/css/style.css") + +data_pre_output = (cur_work_dir / "data").resolve() +default_model_output = (cur_work_dir / "results").resolve() +default_filelist = data_pre_output / "detect.list" +data_pre_output.mkdir(parents=True, exist_ok=True) + +items = [] +dict_items = {} + + +def load_yaml_data_in_fact(yml_path): + with open(yml_path, "r", encoding="utf-8") as file: + yml = yaml.safe_load(file) + return yml + + +def write_yaml_data_in_fact(yml, yml_path): + with open(yml_path, "w", encoding="utf-8") as 
file: + yaml.safe_dump(yml, file, allow_unicode=True) + return yml + + +def generate_tree(directory, depth=0, max_depth=None, prefix=""): + if max_depth is not None and depth > max_depth: + return "" + + tree_str = "" + files = [] + directories = [] + for item in os.listdir(directory): + if os.path.isdir(os.path.join(directory, item)): + directories.append(item) + else: + files.append(item) + + entries = directories + files + for i, entry in enumerate(entries): + connector = "├── " if i < len(entries) - 1 else "└── " + tree_str += f"{prefix}{connector}{entry}
" + if i < len(directories): + extension = "│ " if i < len(entries) - 1 else " " + tree_str += generate_tree( + os.path.join(directory, entry), + depth + 1, + max_depth, + prefix=prefix + extension, + ) + return tree_str + + +def new_explorer(data_path, max_depth): + return gr.Markdown( + elem_classes=["scrollable-component"], + value=generate_tree(data_path, max_depth=max_depth), + ) + + +def add_item( + folder: str, + method: str, + label_lang: str, + if_initial_prompt: bool, + initial_prompt: str | None, +): + folder = folder.strip(" ").strip('"') + + folder_path = Path(folder) + + if folder and folder not in items and data_pre_output not in folder_path.parents: + if folder_path.is_dir(): + items.append(folder) + dict_items[folder] = dict( + type="folder", + method=method, + label_lang=label_lang, + initial_prompt=initial_prompt if if_initial_prompt else None, + ) + elif folder: + err = folder + return gr.Checkboxgroup(choices=items), build_html_error_message( + i18n("Invalid path: {}").format(err) + ) + + formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) + logger.info("After Adding: " + formatted_data) + gr.Info(formatted_data) + return gr.Checkboxgroup(choices=items), build_html_ok_message( + i18n("Added path successfully!") + ) + + +def remove_items(selected_items): + global items, dict_items + to_remove = [item for item in items if item in selected_items] + for item in to_remove: + del dict_items[item] + items = [item for item in items if item in dict_items.keys()] + formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) + logger.info(formatted_data) + gr.Warning("After Removing: " + formatted_data) + return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message( + i18n("Removed path successfully!") + ) + + +def show_selected(options): + selected_options = ", ".join(options) + + if options: + return i18n("Selected: {}").format(selected_options) + else: + return i18n("No selected options") + + +from pydub import AudioSegment + + +def convert_to_mono_in_place(audio_path: Path): + audio = AudioSegment.from_file(audio_path) + if audio.channels > 1: + mono_audio = audio.set_channels(1) + mono_audio.export(audio_path, format=audio_path.suffix[1:]) + logger.info(f"Convert {audio_path} successfully") + + +def list_copy(list_file_path, method): + wav_root = data_pre_output + lst = [] + with list_file_path.open("r", encoding="utf-8") as file: + for line in tqdm(file, desc="Processing audio/transcript"): + wav_path, speaker_name, language, text = line.strip().split("|") + original_wav_path = Path(wav_path) + target_wav_path = ( + wav_root / original_wav_path.parent.name / original_wav_path.name + ) + lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}") + if target_wav_path.is_file(): + continue + target_wav_path.parent.mkdir(parents=True, exist_ok=True) + if method == i18n("Copy"): + shutil.copy(original_wav_path, target_wav_path) + else: + shutil.move(original_wav_path, target_wav_path.parent) + convert_to_mono_in_place(target_wav_path) + original_lab_path = original_wav_path.with_suffix(".lab") + target_lab_path = ( + wav_root + / original_wav_path.parent.name + / original_wav_path.with_suffix(".lab").name + ) + if target_lab_path.is_file(): + continue + if method == i18n("Copy"): + shutil.copy(original_lab_path, target_lab_path) + else: + shutil.move(original_lab_path, target_lab_path.parent) + + if method == i18n("Move"): + with list_file_path.open("w", encoding="utf-8") as file: + file.writelines("\n".join(lst)) + + del lst + 
return build_html_ok_message(i18n("Use filelist")) + + +def check_files(data_path: str, max_depth: int, label_model: str, label_device: str): + global dict_items + data_path = Path(data_path) + gr.Warning("Pre-processing begins...") + for item, content in dict_items.items(): + item_path = Path(item) + tar_path = data_path / item_path.name + + if content["type"] == "folder" and item_path.is_dir(): + if content["method"] == i18n("Copy"): + os.makedirs(tar_path, exist_ok=True) + shutil.copytree( + src=str(item_path), dst=str(tar_path), dirs_exist_ok=True + ) + elif not tar_path.is_dir(): + shutil.move(src=str(item_path), dst=str(tar_path)) + + for suf in ["wav", "flac", "mp3"]: + for audio_path in tar_path.glob(f"**/*.{suf}"): + convert_to_mono_in_place(audio_path) + + cur_lang = content["label_lang"] + initial_prompt = content["initial_prompt"] + + transcribe_cmd = [ + PYTHON, + "tools/whisper_asr.py", + "--model-size", + label_model, + "--device", + label_device, + "--audio-dir", + tar_path, + "--save-dir", + tar_path, + "--language", + cur_lang, + ] + + if initial_prompt is not None: + transcribe_cmd += ["--initial-prompt", initial_prompt] + + if cur_lang != "IGNORE": + try: + gr.Warning("Begin To Transcribe") + subprocess.run( + transcribe_cmd, + env=env, + ) + except Exception: + print("Transcription error occurred") + + elif content["type"] == "file" and item_path.is_file(): + list_copy(item_path, content["method"]) + + return build_html_ok_message(i18n("Move files successfully")), new_explorer( + data_path, max_depth=max_depth + ) + + +def generate_folder_name(): + now = datetime.datetime.now() + folder_name = now.strftime("%Y%m%d_%H%M%S") + return folder_name + + +def train_process( + data_path: str, + option: str, + # llama config + llama_ckpt, + llama_base_config, + llama_lr, + llama_maxsteps, + llama_data_num_workers, + llama_data_batch_size, + llama_data_max_length, + llama_precision, + llama_check_interval, + llama_grad_batches, + llama_use_speaker, + llama_use_lora, +): + + backend = "nccl" if sys.platform == "linux" else "gloo" + + new_project = generate_folder_name() + print("New Project Name: ", new_project) + + if option == "VQGAN": + msg = "Skipped VQGAN Training." + gr.Warning(msg) + logger.info(msg) + + if option == "LLAMA": + msg = "LLAMA Training begins..." 
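+        # Notify the UI, then run the three subprocesses below: extract VQ tokens from
+        # the preprocessed audio, pack the .npy/.lab pairs into protobuf shards, and
+        # finally launch the text2semantic finetune run.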
+ gr.Warning(msg) + logger.info(msg) + subprocess.run( + [ + PYTHON, + "tools/vqgan/extract_vq.py", + str(data_pre_output), + "--num-workers", + "1", + "--batch-size", + "16", + "--config-name", + "firefly_gan_vq", + "--checkpoint-path", + "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ] + ) + + subprocess.run( + [ + PYTHON, + "tools/llama/build_dataset.py", + "--input", + str(data_pre_output), + "--text-extension", + ".lab", + "--num-workers", + "16", + ] + ) + ckpt_path = "checkpoints/fish-speech-1.4/model.pth" + lora_prefix = "lora_" if llama_use_lora else "" + llama_name = lora_prefix + "text2semantic_" + new_project + latest = next( + iter( + sorted( + [ + str(p.relative_to("results")) + for p in Path("results").glob(lora_prefix + "text2sem*/") + ], + reverse=True, + ) + ), + llama_name, + ) + project = ( + llama_name + if llama_ckpt == i18n("new") + else ( + latest + if llama_ckpt == i18n("latest") + else Path(llama_ckpt).relative_to("results") + ) + ) + logger.info(project) + + if llama_check_interval > llama_maxsteps: + llama_check_interval = llama_maxsteps + + train_cmd = [ + PYTHON, + "fish_speech/train.py", + "--config-name", + "text2semantic_finetune", + f"project={project}", + f"trainer.strategy.process_group_backend={backend}", + f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}", + f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}", + f"model.optimizer.lr={llama_lr}", + f"trainer.max_steps={llama_maxsteps}", + f"data.num_workers={llama_data_num_workers}", + f"data.batch_size={llama_data_batch_size}", + f"max_length={llama_data_max_length}", + f"trainer.precision={llama_precision}", + f"trainer.val_check_interval={llama_check_interval}", + f"trainer.accumulate_grad_batches={llama_grad_batches}", + f"train_dataset.interactive_prob={llama_use_speaker}", + ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else []) + logger.info(train_cmd) + subprocess.run(train_cmd) + + return build_html_ok_message(i18n("Training stopped")) + + +def tensorboard_process( + if_tensorboard: bool, + tensorboard_dir: str, + host: str, + port: str, +): + global p_tensorboard + if if_tensorboard == True and p_tensorboard == None: + url = f"http://{host}:{port}" + yield build_html_ok_message( + i18n("Tensorboard interface is launched at {}").format(url) + ) + prefix = ["tensorboard"] + if Path("fishenv").exists(): + prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"] + + p_tensorboard = subprocess.Popen( + prefix + + [ + "--logdir", + tensorboard_dir, + "--host", + host, + "--port", + port, + "--reload_interval", + "120", + ] + ) + elif if_tensorboard == False and p_tensorboard != None: + kill_process(p_tensorboard.pid) + p_tensorboard = None + yield build_html_error_message(i18n("Tensorboard interface is closed")) + + +def fresh_tb_dir(): + return gr.Dropdown( + choices=[str(p) for p in Path("results").glob("**/tensorboard/")] + ) + + +def list_decoder_models(): + paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")] + if not paths: + logger.warning("No decoder model found") + return paths + + +def list_llama_models(): + choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")] + choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")] + choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")] + choices = sorted(choices, reverse=True) + if not choices: + logger.warning("No LLaMA model found") + return choices + 
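+
+# The two listing helpers above (and the LoRA helper below) resolve checkpoints purely
+# by globbing. As a rough, illustrative sketch assuming the stock layout fetched by
+# tools/download_models.py (the example result paths are assumptions, not files
+# created by this patch):
+#
+#   >>> from pathlib import Path
+#   >>> [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
+#   ['checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth']
+#   >>> sorted(str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth"))
+#   ['checkpoints/fish-speech-1.4']
+#   >>> sorted(str(p) for p in Path("results").glob("lora*/**/*.ckpt"))
+#   ['results/lora_text2semantic_20240101_000000/checkpoints/step_000000050.ckpt']
+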
+ +def list_lora_llama_models(): + choices = sorted( + [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True + ) + if not choices: + logger.warning("No LoRA LLaMA model found") + return choices + + +def fresh_decoder_model(): + return gr.Dropdown(choices=list_decoder_models()) + + +def fresh_llama_ckpt(llama_use_lora): + return gr.Dropdown( + choices=[i18n("latest"), i18n("new")] + + ( + [str(p) for p in Path("results").glob("text2sem*/")] + if not llama_use_lora + else [str(p) for p in Path("results").glob("lora_*/")] + ) + ) + + +def fresh_llama_model(): + return gr.Dropdown(choices=list_llama_models()) + + +def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output): + if ( + lora_weight is None + or not Path(lora_weight).exists() + or not Path(llama_weight).exists() + ): + return build_html_error_message( + i18n( + "Path error, please check the model file exists in the corresponding path" + ) + ) + gr.Warning("Merging begins...") + merge_cmd = [ + PYTHON, + "tools/llama/merge_lora.py", + "--lora-config", + "r_8_alpha_16", + "--lora-weight", + lora_weight, + "--output", + llama_lora_output + "_" + generate_folder_name(), + ] + logger.info(merge_cmd) + subprocess.run(merge_cmd) + return build_html_ok_message(i18n("Merge successfully")) + + +def llama_quantify(llama_weight, quantify_mode): + if llama_weight is None or not Path(llama_weight).exists(): + return build_html_error_message( + i18n( + "Path error, please check the model file exists in the corresponding path" + ) + ) + + gr.Warning("Quantifying begins...") + + now = generate_folder_name() + quantify_cmd = [ + PYTHON, + "tools/llama/quantize.py", + "--checkpoint-path", + llama_weight, + "--mode", + quantify_mode, + "--timestamp", + now, + ] + logger.info(quantify_cmd) + subprocess.run(quantify_cmd) + if quantify_mode == "int8": + quantize_path = str( + Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}" + ) + else: + quantize_path = str( + Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}" + ) + return build_html_ok_message( + i18n("Quantify successfully") + f"Path: {quantize_path}" + ) + + +init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path) +init_llama_yml = load_yaml_data_in_fact(llama_yml_path) + +with gr.Blocks( + head="", + js=js, + theme=seafoam, + analytics_enabled=False, + title="Fish Speech", +) as demo: + with gr.Row(): + with gr.Column(): + with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")): + with gr.Row(): + textbox = gr.Textbox( + label="\U0000270F " + + i18n("Input Audio & Source Path for Transcription"), + info=i18n("Speaker is identified by the folder name"), + interactive=True, + ) + with gr.Row(equal_height=False): + with gr.Column(): + output_radio = gr.Radio( + label="\U0001F4C1 " + + i18n("Select source file processing method"), + choices=[i18n("Copy"), i18n("Move")], + value=i18n("Copy"), + interactive=True, + ) + with gr.Column(): + error = gr.HTML(label=i18n("Error Message")) + if_label = gr.Checkbox( + label=i18n("Open Labeler WebUI"), scale=0, show_label=True + ) + + with gr.Row(): + label_device = gr.Dropdown( + label=i18n("Labeling Device"), + info=i18n( + "It is recommended to use CUDA, if you have low configuration, use CPU" + ), + choices=["cpu", "cuda"], + value="cuda", + interactive=True, + ) + label_model = gr.Dropdown( + label=i18n("Whisper Model"), + info=i18n("Faster Whisper, Up to 5g GPU memory usage"), + choices=["large-v3", "medium"], + value="large-v3", + interactive=True, + ) + label_radio = 
gr.Dropdown( + label=i18n("Optional Label Language"), + info=i18n( + "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format" + ), + choices=[ + (i18n("Chinese"), "zh"), + (i18n("English"), "en"), + (i18n("Japanese"), "ja"), + (i18n("Disabled"), "IGNORE"), + (i18n("auto"), "auto"), + ], + value="IGNORE", + interactive=True, + ) + + with gr.Row(): + if_initial_prompt = gr.Checkbox( + value=False, + label=i18n("Enable Initial Prompt"), + min_width=120, + scale=0, + ) + initial_prompt = gr.Textbox( + label=i18n("Initial Prompt"), + info=i18n( + "Initial prompt can provide contextual or vocabulary-specific guidance to the model." + ), + placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.", + interactive=False, + ) + + with gr.Row(): + add_button = gr.Button( + "\U000027A1 " + i18n("Add to Processing Area"), + variant="primary", + ) + remove_button = gr.Button( + "\U000026D4 " + i18n("Remove Selected Data") + ) + + with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")): + with gr.Row(): + model_type_radio = gr.Radio( + label=i18n( + "Select the model to be trained (Depending on the Tab page you are on)" + ), + interactive=False, + choices=["VQGAN", "LLAMA"], + value="VQGAN", + ) + with gr.Row(): + with gr.Tabs(): + with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page: + gr.HTML("You don't need to train this model!") + + with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page: + with gr.Row(equal_height=False): + llama_use_lora = gr.Checkbox( + label=i18n("Use LoRA"), + info=i18n( + "Use LoRA can save GPU memory, but may reduce the quality of the model" + ), + value=True, + interactive=True, + ) + llama_ckpt = gr.Dropdown( + label=i18n("Select LLAMA ckpt"), + choices=[i18n("latest"), i18n("new")] + + [ + str(p) + for p in Path("results").glob("text2sem*/") + ] + + [str(p) for p in Path("results").glob("lora*/")], + value=i18n("latest"), + interactive=True, + ) + with gr.Row(equal_height=False): + llama_lr_slider = gr.Slider( + label=i18n("Initial Learning Rate"), + info=i18n( + "lr smaller -> usually train slower but more stable" + ), + interactive=True, + minimum=1e-5, + maximum=1e-4, + step=1e-5, + value=5e-5, + ) + llama_maxsteps_slider = gr.Slider( + label=i18n("Maximum Training Steps"), + info=i18n( + "recommend: max_steps = num_audios // batch_size * (2 to 5)" + ), + interactive=True, + minimum=1, + maximum=10000, + step=1, + value=50, + ) + with gr.Row(equal_height=False): + llama_base_config = gr.Dropdown( + label=i18n("Model Size"), + choices=[ + "text2semantic_finetune", + ], + value="text2semantic_finetune", + ) + llama_data_num_workers_slider = gr.Slider( + label=i18n("Number of Workers"), + minimum=1, + maximum=32, + step=1, + value=4, + ) + with gr.Row(equal_height=False): + llama_data_batch_size_slider = gr.Slider( + label=i18n("Batch Size"), + interactive=True, + minimum=1, + maximum=32, + step=1, + value=2, + ) + llama_data_max_length_slider = gr.Slider( + label=i18n("Maximum Length per Sample"), + interactive=True, + minimum=1024, + maximum=4096, + step=128, + value=2048, + ) + with gr.Row(equal_height=False): + llama_precision_dropdown = gr.Dropdown( + label=i18n("Precision"), + info=i18n( + "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU" + ), + interactive=True, + choices=["32", "bf16-true", "16-mixed"], + value="bf16-true", + ) + llama_check_interval_slider = gr.Slider( + label=i18n("Save model every n 
steps"), + info=i18n( + "make sure that it's not greater than max_steps" + ), + interactive=True, + minimum=1, + maximum=1000, + step=1, + value=50, + ) + with gr.Row(equal_height=False): + llama_grad_batches = gr.Slider( + label=i18n("Accumulate Gradient Batches"), + interactive=True, + minimum=1, + maximum=20, + step=1, + value=init_llama_yml["trainer"][ + "accumulate_grad_batches" + ], + ) + llama_use_speaker = gr.Slider( + label=i18n( + "Probability of applying Speaker Condition" + ), + interactive=True, + minimum=0.1, + maximum=1.0, + step=0.05, + value=init_llama_yml["train_dataset"][ + "interactive_prob" + ], + ) + + with gr.Tab(label=i18n("Merge LoRA"), id=4): + with gr.Row(equal_height=False): + llama_weight = gr.Dropdown( + label=i18n("Base LLAMA Model"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + "checkpoints/fish-speech-1.4/model.pth", + ], + value="checkpoints/fish-speech-1.4/model.pth", + allow_custom_value=True, + interactive=True, + ) + with gr.Row(equal_height=False): + lora_weight = gr.Dropdown( + label=i18n("LoRA Model to be merged"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + str(p) + for p in Path("results").glob("lora*/**/*.ckpt") + ], + allow_custom_value=True, + interactive=True, + ) + lora_llama_config = gr.Dropdown( + label=i18n("LLAMA Model Config"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=[ + "text2semantic_finetune", + ], + value="text2semantic_finetune", + allow_custom_value=True, + ) + with gr.Row(equal_height=False): + llama_lora_output = gr.Dropdown( + label=i18n("Output Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + value="checkpoints/merged", + choices=["checkpoints/merged"], + allow_custom_value=True, + interactive=True, + ) + with gr.Row(equal_height=False): + llama_lora_merge_btn = gr.Button( + value=i18n("Merge"), variant="primary" + ) + + with gr.Tab(label=i18n("Model Quantization"), id=5): + with gr.Row(equal_height=False): + llama_weight_to_quantify = gr.Dropdown( + label=i18n("Base LLAMA Model"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=list_llama_models(), + value="checkpoints/fish-speech-1.4", + allow_custom_value=True, + interactive=True, + ) + quantify_mode = gr.Dropdown( + label=i18n("Post-quantification Precision"), + info=i18n( + "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase" + ), + choices=["int8", "int4"], + value="int8", + allow_custom_value=False, + interactive=True, + ) + with gr.Row(equal_height=False): + llama_quantify_btn = gr.Button( + value=i18n("Quantify"), variant="primary" + ) + + with gr.Tab(label="Tensorboard", id=6): + with gr.Row(equal_height=False): + tb_host = gr.Textbox( + label=i18n("Tensorboard Host"), value="127.0.0.1" + ) + tb_port = gr.Textbox( + label=i18n("Tensorboard Port"), value="11451" + ) + with gr.Row(equal_height=False): + tb_dir = gr.Dropdown( + label=i18n("Tensorboard Log Path"), + allow_custom_value=True, + choices=[ + str(p) + for p in Path("results").glob("**/tensorboard/") + ], + ) + with gr.Row(equal_height=False): + if_tb = gr.Checkbox( + label=i18n("Open Tensorboard"), + ) + + with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")): + with gr.Column(): + with gr.Row(): + with gr.Accordion( + label="\U0001F5A5 " + + i18n("Inference Server Configuration"), + open=False, + ): + with gr.Row(): + infer_host_textbox = gr.Textbox( + label=i18n("WebUI Host"), 
value="127.0.0.1" + ) + infer_port_textbox = gr.Textbox( + label=i18n("WebUI Port"), value="7862" + ) + with gr.Row(): + infer_decoder_model = gr.Dropdown( + label=i18n("Decoder Model Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + choices=list_decoder_models(), + value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + allow_custom_value=True, + ) + infer_decoder_config = gr.Dropdown( + label=i18n("Decoder Model Config"), + info=i18n("Changing with the Model Path"), + value="firefly_gan_vq", + choices=[ + "firefly_gan_vq", + ], + allow_custom_value=True, + ) + with gr.Row(): + infer_llama_model = gr.Dropdown( + label=i18n("LLAMA Model Path"), + info=i18n( + "Type the path or select from the dropdown" + ), + value="checkpoints/fish-speech-1.4", + choices=list_llama_models(), + allow_custom_value=True, + ) + + with gr.Row(): + infer_compile = gr.Radio( + label=i18n("Compile Model"), + info=i18n( + "Compile the model can significantly reduce the inference time, but will increase cold start time" + ), + choices=["Yes", "No"], + value=( + "Yes" if (sys.platform == "linux") else "No" + ), + interactive=is_module_installed("triton"), + ) + + with gr.Row(): + infer_checkbox = gr.Checkbox( + label=i18n("Open Inference Server") + ) + infer_error = gr.HTML(label=i18n("Inference Server Error")) + + with gr.Column(): + train_error = gr.HTML(label=i18n("Training Error")) + checkbox_group = gr.CheckboxGroup( + label="\U0001F4CA " + i18n("Data Source"), + info=i18n( + "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list." + ), + elem_classes=["data_src"], + ) + train_box = gr.Textbox( + label=i18n("Data Preprocessing Path"), + value=str(data_pre_output), + interactive=False, + ) + model_box = gr.Textbox( + label="\U0001F4BE " + i18n("Model Output Path"), + value=str(default_model_output), + interactive=False, + ) + + with gr.Accordion( + i18n( + "View the status of the preprocessing folder (use the slider to control the depth of the tree)" + ), + elem_classes=["scrollable-component"], + elem_id="file_accordion", + ): + tree_slider = gr.Slider( + minimum=0, + maximum=3, + value=0, + step=1, + show_label=False, + container=False, + ) + file_markdown = new_explorer(str(data_pre_output), 0) + with gr.Row(equal_height=False): + admit_btn = gr.Button( + "\U00002705 " + i18n("File Preprocessing"), + variant="primary", + ) + fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80) + help_button = gr.Button("\U00002753", scale=0, min_width=80) # question + train_btn = gr.Button(i18n("Start Training"), variant="primary") + + footer = load_data_in_raw("fish_speech/webui/html/footer.html") + footer = footer.format( + versions=versions_html(), + api_docs="https://speech.fish.audio/inference/#http-api", + ) + gr.HTML(footer, elem_id="footer") + vqgan_page.select(lambda: "VQGAN", None, model_type_radio) + llama_page.select(lambda: "LLAMA", None, model_type_radio) + add_button.click( + fn=add_item, + inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt], + outputs=[checkbox_group, error], + ) + remove_button.click( + fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error] + ) + checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error]) + help_button.click( + fn=None, + js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, ' + 'toolbar=no, menubar=no, scrollbars=no, resizable=no, 
location=no, status=no")}', + ) + if_label.change(fn=change_label, inputs=[if_label], outputs=[error]) + if_initial_prompt.change( + fn=lambda x: gr.Textbox(value="", interactive=x), + inputs=[if_initial_prompt], + outputs=[initial_prompt], + ) + train_btn.click( + fn=train_process, + inputs=[ + train_box, + model_type_radio, + # llama config + llama_ckpt, + llama_base_config, + llama_lr_slider, + llama_maxsteps_slider, + llama_data_num_workers_slider, + llama_data_batch_size_slider, + llama_data_max_length_slider, + llama_precision_dropdown, + llama_check_interval_slider, + llama_grad_batches, + llama_use_speaker, + llama_use_lora, + ], + outputs=[train_error], + ) + if_tb.change( + fn=tensorboard_process, + inputs=[if_tb, tb_dir, tb_host, tb_port], + outputs=[train_error], + ) + tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir]) + infer_decoder_model.change( + fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model] + ) + infer_llama_model.change( + fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model] + ) + llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight]) + admit_btn.click( + fn=check_files, + inputs=[train_box, tree_slider, label_model, label_device], + outputs=[error, file_markdown], + ) + fresh_btn.click( + fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown] + ) + llama_use_lora.change( + fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] + ) + llama_ckpt.change( + fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt] + ) + lora_weight.change( + fn=lambda: gr.Dropdown(choices=list_lora_llama_models()), + inputs=[], + outputs=[lora_weight], + ) + llama_lora_merge_btn.click( + fn=llama_lora_merge, + inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output], + outputs=[train_error], + ) + llama_quantify_btn.click( + fn=llama_quantify, + inputs=[llama_weight_to_quantify, quantify_mode], + outputs=[train_error], + ) + infer_checkbox.change( + fn=change_infer, + inputs=[ + infer_checkbox, + infer_host_textbox, + infer_port_textbox, + infer_decoder_model, + infer_decoder_config, + infer_llama_model, + infer_compile, + ], + outputs=[infer_error], + ) + +demo.launch(inbrowser=True) diff --git a/tools/__pycache__/api.cpython-310.pyc b/tools/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..918ab37b9be869d228623b4cd76a8a2f49128a01 Binary files /dev/null and b/tools/__pycache__/api.cpython-310.pyc differ diff --git a/tools/__pycache__/commons.cpython-310.pyc b/tools/__pycache__/commons.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0150f60d0fcc14618f65d1caedf818dddcb38a6 Binary files /dev/null and b/tools/__pycache__/commons.cpython-310.pyc differ diff --git a/tools/__pycache__/file.cpython-310.pyc b/tools/__pycache__/file.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f8618e7f0414499fc7de6f14313a8a47003402a Binary files /dev/null and b/tools/__pycache__/file.cpython-310.pyc differ diff --git a/tools/__pycache__/webui.cpython-310.pyc b/tools/__pycache__/webui.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ae4f833880994224ad8a1b1e6fac966c841f5d1 Binary files /dev/null and b/tools/__pycache__/webui.cpython-310.pyc differ diff --git a/tools/api.py b/tools/api.py new file mode 100644 index 0000000000000000000000000000000000000000..dc5a47d1e90d247e0fb2c23037dea890a5e35f7f --- /dev/null +++ b/tools/api.py @@ -0,0 +1,440 
@@ +import base64 +import io +import json +import queue +import random +import sys +import traceback +import wave +from argparse import ArgumentParser +from http import HTTPStatus +from pathlib import Path +from typing import Annotated, Any, Literal, Optional + +import numpy as np +import ormsgpack +import pyrootutils +import soundfile as sf +import torch +import torchaudio +from baize.datastructures import ContentType +from kui.asgi import ( + Body, + FactoryClass, + HTTPException, + HttpRequest, + HttpView, + JSONResponse, + Kui, + OpenAPI, + StreamResponse, +) +from kui.asgi.routing import MultimethodRoutes +from loguru import logger +from pydantic import BaseModel, Field, conint + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +# from fish_speech.models.vqgan.lit_module import VQGAN +from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture +from fish_speech.text.chn_text_norm.text import Text as ChnNormedText +from fish_speech.utils import autocast_exclude_mps +from tools.commons import ServeReferenceAudio, ServeTTSRequest +from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text +from tools.llama.generate import ( + GenerateRequest, + GenerateResponse, + WrappedGenerateResponse, + launch_thread_safe_queue, +) +from tools.vqgan.inference import load_model as load_decoder_model + + +def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + return wav_header_bytes + + +# Define utils for web server +async def http_execption_handler(exc: HTTPException): + return JSONResponse( + dict( + statusCode=exc.status_code, + message=exc.content, + error=HTTPStatus(exc.status_code).phrase, + ), + exc.status_code, + exc.headers, + ) + + +async def other_exception_handler(exc: "Exception"): + traceback.print_exc() + + status = HTTPStatus.INTERNAL_SERVER_ERROR + return JSONResponse( + dict(statusCode=status, message=str(exc), error=status.phrase), + status, + ) + + +def load_audio(reference_audio, sr): + if len(reference_audio) > 255 or not Path(reference_audio).exists(): + audio_data = reference_audio + reference_audio = io.BytesIO(audio_data) + + waveform, original_sr = torchaudio.load( + reference_audio, backend="sox" if sys.platform == "linux" else "soundfile" + ) + + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + if original_sr != sr: + resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr) + waveform = resampler(waveform) + + audio = waveform.squeeze().numpy() + return audio + + +def encode_reference(*, decoder_model, reference_audio, enable_reference_audio): + if enable_reference_audio and reference_audio is not None: + # Load audios, and prepare basic info here + reference_audio_content = load_audio( + reference_audio, decoder_model.spec_transform.sample_rate + ) + + audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[ + None, None, : + ] + audio_lengths = torch.tensor( + [audios.shape[2]], device=decoder_model.device, dtype=torch.long + ) + logger.info( + f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds" + ) + + # VQ Encoder + if isinstance(decoder_model, FireflyArchitecture): + prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0] + + 
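+        # NOTE: encode() returns batched results; prompt_tokens is assumed to be the
+        # (num_codebooks, T) index grid for the single reference clip.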
logger.info(f"Encoded prompt: {prompt_tokens.shape}") + else: + prompt_tokens = None + logger.info("No reference audio provided") + + return prompt_tokens + + +def decode_vq_tokens( + *, + decoder_model, + codes, +): + feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device) + logger.info(f"VQ features: {codes.shape}") + + if isinstance(decoder_model, FireflyArchitecture): + # VQGAN Inference + return decoder_model.decode( + indices=codes[None], + feature_lengths=feature_lengths, + )[0].squeeze() + + raise ValueError(f"Unknown model type: {type(decoder_model)}") + + +routes = MultimethodRoutes(base_class=HttpView) + + +def get_content_type(audio_format): + if audio_format == "wav": + return "audio/wav" + elif audio_format == "flac": + return "audio/flac" + elif audio_format == "mp3": + return "audio/mpeg" + else: + return "application/octet-stream" + + +@torch.inference_mode() +def inference(req: ServeTTSRequest): + + idstr: str | None = req.reference_id + if idstr is not None: + ref_folder = Path("references") / idstr + ref_folder.mkdir(parents=True, exist_ok=True) + ref_audios = list_files( + ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False + ) + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=audio_to_bytes(str(ref_audio)), + enable_reference_audio=True, + ) + for ref_audio in ref_audios + ] + prompt_texts = [ + read_ref_text(str(ref_audio.with_suffix(".lab"))) + for ref_audio in ref_audios + ] + + else: + # Parse reference audio aka prompt + refs = req.references + if refs is None: + refs = [] + prompt_tokens = [ + encode_reference( + decoder_model=decoder_model, + reference_audio=ref.audio, + enable_reference_audio=True, + ) + for ref in refs + ] + prompt_texts = [ref.text for ref in refs] + + # LLAMA Inference + request = dict( + device=decoder_model.device, + max_new_tokens=req.max_new_tokens, + text=( + req.text + if not req.normalize + else ChnNormedText(raw_text=req.text).normalize() + ), + top_p=req.top_p, + repetition_penalty=req.repetition_penalty, + temperature=req.temperature, + compile=args.compile, + iterative_prompt=req.chunk_length > 0, + chunk_length=req.chunk_length, + max_length=2048, + prompt_tokens=prompt_tokens, + prompt_text=prompt_texts, + ) + + response_queue = queue.Queue() + llama_queue.put( + GenerateRequest( + request=request, + response_queue=response_queue, + ) + ) + + if req.streaming: + yield wav_chunk_header() + + segments = [] + while True: + result: WrappedGenerateResponse = response_queue.get() + if result.status == "error": + raise result.response + break + + result: GenerateResponse = result.response + if result.action == "next": + break + + with autocast_exclude_mps( + device_type=decoder_model.device.type, dtype=args.precision + ): + fake_audios = decode_vq_tokens( + decoder_model=decoder_model, + codes=result.codes, + ) + + fake_audios = fake_audios.float().cpu().numpy() + + if req.streaming: + yield (fake_audios * 32768).astype(np.int16).tobytes() + else: + segments.append(fake_audios) + + if req.streaming: + return + + if len(segments) == 0: + raise HTTPException( + HTTPStatus.INTERNAL_SERVER_ERROR, + content="No audio generated, please check the input text.", + ) + + fake_audios = np.concatenate(segments, axis=0) + yield fake_audios + + +async def inference_async(req: ServeTTSRequest): + for chunk in inference(req): + yield chunk + + +async def buffer_to_async_generator(buffer): + yield buffer + + +@routes.http.post("/v1/tts") +async def api_invoke_model( + req: 
Annotated[ServeTTSRequest, Body(exclusive=True)], +): + """ + Invoke model and generate audio + """ + + if args.max_text_length > 0 and len(req.text) > args.max_text_length: + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content=f"Text is too long, max length is {args.max_text_length}", + ) + + if req.streaming and req.format != "wav": + raise HTTPException( + HTTPStatus.BAD_REQUEST, + content="Streaming only supports WAV format", + ) + + if req.streaming: + return StreamResponse( + iterable=inference_async(req), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + else: + fake_audios = next(inference(req)) + buffer = io.BytesIO() + sf.write( + buffer, + fake_audios, + decoder_model.spec_transform.sample_rate, + format=req.format, + ) + + return StreamResponse( + iterable=buffer_to_async_generator(buffer.getvalue()), + headers={ + "Content-Disposition": f"attachment; filename=audio.{req.format}", + }, + content_type=get_content_type(req.format), + ) + + +@routes.http.post("/v1/health") +async def api_health(): + """ + Health check + """ + + return JSONResponse({"status": "ok"}) + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--llama-checkpoint-path", + type=str, + default="checkpoints/fish-speech-1.4", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=str, + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-text-length", type=int, default=0) + parser.add_argument("--listen", type=str, default="127.0.0.1:8080") + parser.add_argument("--workers", type=int, default=1) + + return parser.parse_args() + + +# Define Kui app +openapi = OpenAPI( + { + "title": "Fish Speech API", + }, +).routes + + +class MsgPackRequest(HttpRequest): + async def data(self) -> Annotated[Any, ContentType("application/msgpack")]: + if self.content_type == "application/msgpack": + return ormsgpack.unpackb(await self.body) + + raise HTTPException( + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + headers={"Accept": "application/msgpack"}, + ) + + +app = Kui( + routes=routes + openapi[1:], # Remove the default route + exception_handlers={ + HTTPException: http_execption_handler, + Exception: other_exception_handler, + }, + factory_class=FactoryClass(http=MsgPackRequest), + cors_config={}, +) + + +if __name__ == "__main__": + + import uvicorn + + args = parse_args() + args.precision = torch.half if args.half else torch.bfloat16 + + logger.info("Loading Llama model...") + llama_queue = launch_thread_safe_queue( + checkpoint_path=args.llama_checkpoint_path, + device=args.device, + precision=args.precision, + compile=args.compile, + ) + logger.info("Llama model loaded, loading VQ-GAN model...") + + decoder_model = load_decoder_model( + config_name=args.decoder_config_name, + checkpoint_path=args.decoder_checkpoint_path, + device=args.device, + ) + + logger.info("VQ-GAN model loaded, warming up...") + + # Dry run to check if the model is loaded correctly and avoid the first-time latency + list( + inference( + ServeTTSRequest( + text="Hello world.", + references=[], + reference_id=None, + max_new_tokens=1024, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.2, + temperature=0.7, + 
emotion=None, + format="wav", + ) + ) + ) + + logger.info(f"Warming up done, starting server at http://{args.listen}") + host, port = args.listen.split(":") + uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info") diff --git a/tools/commons.py b/tools/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..f81cadec1efd6e4f749c279e64a65ea9caaa3f53 --- /dev/null +++ b/tools/commons.py @@ -0,0 +1,35 @@ +from typing import Annotated, Literal, Optional + +from pydantic import BaseModel, Field, conint + + +class ServeReferenceAudio(BaseModel): + audio: bytes + text: str + + +class ServeTTSRequest(BaseModel): + text: str + chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200 + # Audio format + format: Literal["wav", "pcm", "mp3"] = "wav" + mp3_bitrate: Literal[64, 128, 192] = 128 + # References audios for in-context learning + references: list[ServeReferenceAudio] = [] + # Reference id + # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/ + # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1 + reference_id: str | None = None + # Normalize text for en & zh, this increase stability for numbers + normalize: bool = True + mp3_bitrate: Optional[int] = 64 + opus_bitrate: Optional[int] = -1000 + # Balance mode will reduce latency to 300ms, but may decrease stability + latency: Literal["normal", "balanced"] = "normal" + # not usually used below + streaming: bool = False + emotion: Optional[str] = None + max_new_tokens: int = 1024 + top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 + repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2 + temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7 diff --git a/tools/download_models.py b/tools/download_models.py new file mode 100644 index 0000000000000000000000000000000000000000..9e79c34c43b424a8e47c43dd3edf003634fc667e --- /dev/null +++ b/tools/download_models.py @@ -0,0 +1,55 @@ +import os + +from huggingface_hub import hf_hub_download + + +# Download +def check_and_download_files(repo_id, file_list, local_dir): + os.makedirs(local_dir, exist_ok=True) + for file in file_list: + file_path = os.path.join(local_dir, file) + if not os.path.exists(file_path): + print(f"{file} 不存在,从 Hugging Face 仓库下载...") + hf_hub_download( + repo_id=repo_id, + filename=file, + resume_download=True, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + else: + print(f"{file} 已存在,跳过下载。") + + +# 1st +repo_id_1 = "fishaudio/fish-speech-1.4" +local_dir_1 = "./checkpoints/fish-speech-1.4" +files_1 = [ + "model.pth", + "README.md", + "special_tokens_map.json", + "tokenizer_config.json", + "tokenizer.json", + "config.json", + "firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +] + +# 3rd +repo_id_3 = "fishaudio/fish-speech-1" +local_dir_3 = "./" +files_3 = [ + "ffmpeg.exe", + "ffprobe.exe", +] + +# 4th +repo_id_4 = "SpicyqSama007/fish-speech-packed" +local_dir_4 = "./" +files_4 = [ + "asr-label-win-x64.exe", +] + +check_and_download_files(repo_id_1, files_1, local_dir_1) + +check_and_download_files(repo_id_3, files_3, local_dir_3) +check_and_download_files(repo_id_4, files_4, local_dir_4) diff --git a/tools/extract_model.py b/tools/extract_model.py new file mode 100644 index 0000000000000000000000000000000000000000..97fe62507b7282890319d8dc1eaa3cbca0e1f60a --- /dev/null +++ b/tools/extract_model.py @@ -0,0 +1,21 @@ +import click +import torch +from loguru import logger + + +@click.command() +@click.argument("model_path") 
+@click.argument("output_path") +def main(model_path, output_path): + if model_path == output_path: + logger.error("Model path and output path are the same") + return + + logger.info(f"Loading model from {model_path}") + state_dict = torch.load(model_path, map_location="cpu")["state_dict"] + torch.save(state_dict, output_path) + logger.info(f"Model saved to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/file.py b/tools/file.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a0597365252e7aecf887897ff391a061275c3f --- /dev/null +++ b/tools/file.py @@ -0,0 +1,125 @@ +import base64 +from pathlib import Path +from typing import Union + +from loguru import logger +from natsort import natsorted + +AUDIO_EXTENSIONS = { + ".mp3", + ".wav", + ".flac", + ".ogg", + ".m4a", + ".wma", + ".aac", + ".aiff", + ".aif", + ".aifc", +} + +VIDEO_EXTENSIONS = { + ".mp4", + ".avi", +} + + +def audio_to_bytes(file_path): + if not file_path or not Path(file_path).exists(): + return None + with open(file_path, "rb") as wav_file: + wav = wav_file.read() + return wav + + +def read_ref_text(ref_text): + path = Path(ref_text) + if path.exists() and path.is_file(): + with path.open("r", encoding="utf-8") as file: + return file.read() + return ref_text + + +def list_files( + path: Union[Path, str], + extensions: set[str] = None, + recursive: bool = False, + sort: bool = True, +) -> list[Path]: + """List files in a directory. + + Args: + path (Path): Path to the directory. + extensions (set, optional): Extensions to filter. Defaults to None. + recursive (bool, optional): Whether to search recursively. Defaults to False. + sort (bool, optional): Whether to sort the files. Defaults to True. + + Returns: + list: List of files. + """ + + if isinstance(path, str): + path = Path(path) + + if not path.exists(): + raise FileNotFoundError(f"Directory {path} does not exist.") + + files = [file for ext in extensions for file in path.rglob(f"*{ext}")] + + if sort: + files = natsorted(files) + + return files + + +def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]: + """ + Load a Bert-VITS2 style filelist. 
+ """ + + files = set() + results = [] + count_duplicated, count_not_found = 0, 0 + + LANGUAGE_TO_LANGUAGES = { + "zh": ["zh", "en"], + "jp": ["jp", "en"], + "en": ["en"], + } + + with open(path, "r", encoding="utf-8") as f: + for line in f.readlines(): + splits = line.strip().split("|", maxsplit=3) + if len(splits) != 4: + logger.warning(f"Invalid line: {line}") + continue + + filename, speaker, language, text = splits + file = Path(filename) + language = language.strip().lower() + + if language == "ja": + language = "jp" + + assert language in ["zh", "jp", "en"], f"Invalid language {language}" + languages = LANGUAGE_TO_LANGUAGES[language] + + if file in files: + logger.warning(f"Duplicated file: {file}") + count_duplicated += 1 + continue + + if not file.exists(): + logger.warning(f"File not found: {file}") + count_not_found += 1 + continue + + results.append((file, speaker, languages, text)) + + if count_duplicated > 0: + logger.warning(f"Total duplicated files: {count_duplicated}") + + if count_not_found > 0: + logger.warning(f"Total files not found: {count_not_found}") + + return results diff --git a/tools/llama/__pycache__/generate.cpython-310.pyc b/tools/llama/__pycache__/generate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f1ed4a971b86371aabf24f6bb3cdcf38135af5b Binary files /dev/null and b/tools/llama/__pycache__/generate.cpython-310.pyc differ diff --git a/tools/llama/build_dataset.py b/tools/llama/build_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5ef120cce2e04b24f0f897e49f022cb1946c97 --- /dev/null +++ b/tools/llama/build_dataset.py @@ -0,0 +1,169 @@ +import itertools +import os +import re +from collections import defaultdict +from functools import partial +from multiprocessing import Pool +from pathlib import Path + +import click +import numpy as np +from loguru import logger +from tqdm import tqdm + +from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData +from fish_speech.datasets.protos.text_data_stream import pack_pb_stream +from tools.file import load_filelist + +# To avoid CPU overload +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" + + +def task_generator_folder(root: Path, text_extension: str): + files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}")) + files = sorted(files) + + grouped_files = defaultdict(list) + for file in tqdm(files, desc=f"Grouping {root}"): + p = str(file.parent) + speaker = file.parent.name + + try: + if isinstance(text_extension, str): + texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")] + else: + texts = [ + file.with_suffix(ext).read_text(encoding="utf-8") + for ext in text_extension + ] + except Exception as e: + logger.error(f"Failed to read text {file}: {e}") + continue + + grouped_files[p].append((speaker, file, texts)) + + logger.info( + f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..." 
+ ) + + for i in grouped_files.values(): + subset = [(f, t) for _, f, t in i] + yield i[0][0], subset, "folder" + + +def task_generator_filelist(filelist): + grouped_files = defaultdict(list) + for filename, speaker, _, text in load_filelist(filelist): + grouped_files[speaker].append((Path(filename), [text])) + + logger.info(f"Found {len(grouped_files)} groups in {filelist}") + for speaker, values in grouped_files.items(): + yield speaker, values, "filelist" + + +def run_task(task): + name, subset, source = task + + # Parse the files + sentences = [] + for file, texts in subset: + np_file = file.with_suffix(".npy") + if np_file.exists() is False: + logger.warning(f"Can't find {np_file}") + continue + + new_texts = [] + + for text in texts: + # Simple cleaning: replace { xxx } and < xxx > with space + text = re.sub(r"\{.*?\}", " ", text) + text = re.sub(r"<.*?>", " ", text) + text = re.sub(r"\s+", " ", text) + new_texts.append(text) + + try: + semantics = np.load(np_file) + except Exception as e: + logger.error(f"Failed to parse {file}: {e}") + continue + + if isinstance(semantics, np.ndarray): + semantics = semantics.tolist() + + sentences.append( + Sentence( + texts=new_texts, + semantics=[Semantics(values=s) for s in semantics], + ) + ) + + # Pack the sentences + return pack_pb_stream( + TextData( + source=source, + name=name, + sentences=sentences, + ) + ) + + +@click.command() +@click.option( + "--input", + type=click.Path(path_type=Path), + required=True, + help="A folder containing the dataset or a filelist", + multiple=True, +) +@click.option( + "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft" +) +@click.option("--num-workers", type=int, default=16) +@click.option("--text-extension", type=str, default=[".txt"], multiple=True) +@click.option( + "--shard-size", type=int, default=10, help="The maximum size of each shard in mb" +) +def main(input, output, num_workers, text_extension, shard_size): + generator_fns = [] + + for f in input: + assert f.exists(), f"{f} not found" + + if f.is_dir(): + generator_fn = task_generator_folder(f, text_extension) + else: + generator_fn = task_generator_filelist(f) + + generator_fns.append(generator_fn) + + generator_fn = itertools.chain(*generator_fns) + output.mkdir(parents=True, exist_ok=True) + + dataset_fp = None + tar_idx = 0 + written_size = 0 + + with Pool(num_workers) as p: + for result in tqdm(p.imap_unordered(run_task, generator_fn)): + if dataset_fp is None: + dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb") + + dataset_fp.write(result) + written_size += len(result) + + if written_size > shard_size * 1024 * 1024: + logger.info(f"Finished writing {tar_idx} shards to {output}") + dataset_fp.close() + dataset_fp = None + written_size = 0 + tar_idx += 1 + + if dataset_fp is not None: + dataset_fp.close() + + logger.info(f"Finished writing {tar_idx + 1} shards to {output}") + + +if __name__ == "__main__": + main() diff --git a/tools/llama/eval_in_context.py b/tools/llama/eval_in_context.py new file mode 100644 index 0000000000000000000000000000000000000000..30d70940487388185381246d8210a49a58e55743 --- /dev/null +++ b/tools/llama/eval_in_context.py @@ -0,0 +1,171 @@ +import pyrootutils +import torch +import torch.nn.functional as F +from matplotlib import pyplot as plt +from transformers import AutoTokenizer + +# register eval resolver and root +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + +from torch.utils.data import DataLoader + +from 
fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator +from tools.llama.generate import load_model + + +def smooth( + scalars: list[float], weight: float +) -> list[float]: # Weight between 0 and 1 + last = scalars[0] # First value in the plot (first timestep) + smoothed = list() + for point in scalars: + smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value + smoothed.append(smoothed_val) # Save it + last = smoothed_val # Anchor the last smoothed value + + return smoothed + + +@torch.inference_mode() +def analyze_one_model(loader, config, weight, max_length): + device = "cuda" if torch.cuda.is_available() else "cpu" + model = load_model( + config, + weight, + device, + torch.bfloat16, + max_length, + compile=False, + )[0] + + current_step = 0 + model.eval() + + semantic_loss_sum = torch.zeros( + max_length, + dtype=torch.float32, + device=device, + ) + counter = torch.zeros( + max_length, + dtype=torch.long, + device=device, + ) + + for batch in loader: + batch = {k: v.to(device) for k, v in batch.items()} + + labels = batch["labels"] + outputs = model( + inp=batch["inputs"], + key_padding_mask=batch["attention_masks"], + ) + + token_logits = outputs.token_logits + codebook_logits = outputs.codebook_logits + + # Generate labels + base_loss = F.cross_entropy( + token_logits.reshape(-1, token_logits.size(-1)), + labels[:, 0].reshape(-1), + ignore_index=-100, + reduction="none", + ) + + codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT + semantic_loss = F.cross_entropy( + codebook_logits.reshape(-1, codebook_logits.size(-1)), + codebook_labels.reshape(-1), + ignore_index=-100, + reduction="none", + ) + + base_loss = base_loss.reshape(labels[:, 0].shape) + semantic_loss = semantic_loss.reshape(codebook_labels.shape) + + semantic_loss_frame = semantic_loss.mean(-1) + pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks + + for loss_sample, pad in zip(semantic_loss_frame, pad_pos): + semantic_loss_sum[~pad] += loss_sample[~pad] + counter[~pad] += 1 + + current_step += 1 + if current_step == 10: + break + + semantic_loss = semantic_loss.cpu() + counter = counter.cpu() + xs, ys = [], [] + + for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)): + if count > 0: + xs.append(i) + ys.append((loss / count).item()) # for better loss visualization + + smoothed_ys = smooth(ys, 0.95) + + # Unload model + del model + torch.cuda.empty_cache() + + return xs, ys, smoothed_ys + + +def main(): + tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1") + max_length = 4096 + + ds = AutoAugTextDataset( + ["data/protos/sft/云天河"], + tokenizer=tokenizer, + use_speaker=False, + interactive_prob=1.0, + max_length=max_length, + ) + + loader = DataLoader( + ds, + batch_size=8, + collate_fn=TextDataCollator(tokenizer, max_length=max_length), + num_workers=0, + shuffle=False, + ) + + plt.figure(figsize=(10, 5), dpi=200) + + plt.xlabel("Frame") + plt.ylabel("Loss") + plt.yscale("log") + plt.title("Semantic Loss") + plt.grid(which="both", axis="both") + plt.xlim(0, max_length) + + tests = [ + ( + "pertrain-medium", + "dual_ar_2_codebook_medium", + "checkpoints/text2semantic-pretrain-medium-2k-v1.pth", + ), + ( + "sft-medium", + "dual_ar_2_codebook_medium", + "checkpoints/text2semantic-sft-medium-v1.1-4k.pth", + ), + ( + "sft-large", + "dual_ar_2_codebook_large", + "checkpoints/text2semantic-sft-large-v1.1-4k.pth", + ), + ] + + for name, config, weight in tests: + xs, _, smoothed_ys = analyze_one_model(loader, config, 
weight, max_length) + plt.plot(xs, smoothed_ys, label=name) + + plt.legend() + plt.savefig("semantic_loss.png") + + +if __name__ == "__main__": + main() diff --git a/tools/llama/generate.py b/tools/llama/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..d717ce7c570933c9cf98e6cd64eacf476affaad5 --- /dev/null +++ b/tools/llama/generate.py @@ -0,0 +1,706 @@ +import os +import queue +import threading +import time +from contextlib import nullcontext +from dataclasses import dataclass +from pathlib import Path +from typing import Literal, Optional, Tuple, Union + +import click +import hydra +import numpy as np +import torch +import torch._dynamo.config +import torch._inductor.config +from loguru import logger +from tqdm import tqdm + +from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID +from fish_speech.text import clean_text, split_text + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True + +if hasattr(torch._inductor.config, "fx_graph_cache"): + # Experimental feature to reduce compilation times, will be on by default in future + torch._inductor.config.fx_graph_cache = True + + +from fish_speech.models.text2semantic.llama import ( + BaseTransformer, + DualARTransformer, + NaiveTransformer, +) + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs( + logits, + previous_tokens: Optional[torch.Tensor] = None, + temperature: torch.Tensor = 1.0, + top_p: torch.Tensor = 1.0, + repetition_penalty: torch.Tensor = 1.0, +) -> torch.Tensor: + # Apply repetition penalty + if previous_tokens is not None: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=0, index=previous_tokens) + score = torch.where( + score < 0, score * repetition_penalty, score / repetition_penalty + ) + logits.scatter_(dim=0, index=previous_tokens, src=score) + + # Apply top-p sampling + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=0, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample( + logits, + previous_tokens: Optional[torch.Tensor] = None, + **sampling_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + probs = logits_to_probs( + logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def decode_one_token_ar( + model: DualARTransformer, + x: torch.Tensor, + input_pos: torch.Tensor, + previous_tokens: torch.Tensor = None, + **sampling_kwargs, +) -> torch.Tensor: + x = model.forward_generate(x, input_pos) + + sampling_kwargs_main = sampling_kwargs.copy() + sampling_kwargs_main["temperature"] = 0.1 + sampling_kwargs_main["top_p"] = 0.1 + sampling_kwargs_main["repetition_penalty"] = 1.0 + + codebooks = [ + sample( + x.logits, + previous_tokens=None, # Disable repetition 
penalty for the token codebook + **sampling_kwargs_main, + )[0] + ] + + x = x.hidden_states + + # Cleanup the cache + for layer in model.fast_layers: + layer.attention.kv_cache.k_cache.fill_(0) + layer.attention.kv_cache.v_cache.fill_(0) + + for codebook_idx in range(model.config.num_codebooks): + input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long) + logits = model.forward_generate_fast(x, input_pos) + a = sample( + logits, + previous_tokens=( + previous_tokens[codebook_idx + 1] + if previous_tokens is not None + else None + ), + **sampling_kwargs, + )[0] + x = model.fast_embeddings(a) + codebooks.append(a) + + return torch.stack(codebooks, dim=0) + + +def decode_one_token_naive( + model: NaiveTransformer, + x: torch.Tensor, + input_pos: torch.Tensor, + previous_tokens: torch.Tensor = None, + **sampling_kwargs, +) -> torch.Tensor: + x = model.forward_generate(x, input_pos) + + sampling_kwargs_main = sampling_kwargs.copy() + sampling_kwargs_main["temperature"] = 0.1 + sampling_kwargs_main["top_p"] = 0.1 + sampling_kwargs_main["repetition_penalty"] = 1.0 + + codebooks = [ + sample( + x.logits, + previous_tokens=None, # Disable repetition penalty for the token codebook + **sampling_kwargs_main, + )[0] + ] + + for i in range(model.config.num_codebooks): + codebooks.append( + sample( + x.codebook_logits[:, :, i], + previous_tokens=( + previous_tokens[i + 1] if previous_tokens is not None else None + ), + **sampling_kwargs, + )[0] + ) + + return torch.stack(codebooks, dim=0) + + +def decode_n_tokens( + model: NaiveTransformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + im_end_id: int = 4, + decode_one_token=decode_one_token_naive, + **sampling_kwargs, +): + previous_tokens = torch.zeros( + (model.config.num_codebooks + 1, model.config.max_seq_len), + dtype=torch.int, + device=cur_token.device, + ) + + for i in tqdm(range(num_new_tokens)): + # We need to get windowed repeat penalty + win_size = 16 + if i < win_size: + window = previous_tokens[:, :win_size] + else: + window = previous_tokens[:, i - win_size : i] + + with ( + torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ) + if torch.cuda.is_available() + else nullcontext() + ): # Actually better for Inductor to codegen attention here + next_token = decode_one_token( + model=model, + x=cur_token, + input_pos=input_pos, + previous_tokens=window, + **sampling_kwargs, + ) + + input_pos += 1 + cur_token = next_token.view(1, model.config.num_codebooks + 1, -1) + previous_tokens[:, i : i + 1] = next_token.view( + model.config.num_codebooks + 1, -1 + ) + + if cur_token[0, 0, -1] == im_end_id: + break + + return previous_tokens[:, : i + 1] + + +@torch.no_grad() +@torch.inference_mode() +def generate( + *, + model: NaiveTransformer, + prompt: torch.Tensor, + max_new_tokens: int, + im_end_id: int = 4, + decode_one_token=decode_one_token_naive, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
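+    The prompt is a (num_codebooks + 1, seq_len) grid of token ids (text row plus
+    codebook rows); the returned sequence includes the prompt followed by the
+    newly generated tokens.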
+ """ + + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(1) + + device, dtype = prompt.device, prompt.dtype + + codebook_dim = 1 + model.config.num_codebooks + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty((codebook_dim, max_new_tokens), dtype=dtype, device=device) + empty[:, :T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + # Use non-accelerated version for now, to avoid compilation overhead + prefill_decode = ( + decode_one_token_naive + if isinstance(model, NaiveTransformer) + else decode_one_token_ar + ) + + next_token = prefill_decode( + model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs + ) + seq[:, T : T + 1] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + x = decode_n_tokens( + model, + next_token.view(1, codebook_dim, -1), + input_pos, + max_new_tokens - 1, + im_end_id=im_end_id, + decode_one_token=decode_one_token, + **sampling_kwargs, + ) + # x = torch.cat(generated_tokens, dim=1) + seq = seq[:, : T + 1 + x.size(1)] + seq[:, T + 1 :] = x + + return seq + + +def encode_tokens( + tokenizer, + string, + device="cuda", + prompt_tokens=None, + num_codebooks=4, +): + string = clean_text(string) + string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n" + + new_tokens = tokenizer.encode( + string, + add_special_tokens=False, + max_length=10**6, + truncation=False, + ) + tokens = torch.tensor([new_tokens], dtype=torch.int, device=device) + + # Codebooks + zeros = ( + torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device) + * CODEBOOK_PAD_TOKEN_ID + ) + prompt = torch.cat((tokens, zeros), dim=0) + + if prompt_tokens is None: + return prompt + + # Get prompt tokens + if prompt_tokens.ndim == 3: + assert ( + prompt_tokens.shape[0] == 1 + ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)" + prompt_tokens = prompt_tokens[0] + + assert prompt_tokens.ndim == 2 + data = prompt_tokens + 1 + + if prompt_tokens.shape[0] > num_codebooks: + logger.warning( + f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks" + ) + data = data[:num_codebooks] + + # Add pad token for each codebook + data = torch.cat( + (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)), + dim=1, + ) + + # Since 1.0, we use <|semantic|> + s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>") + end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + main_token_ids = ( + torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id + ) + main_token_ids[0, -1] = end_token_id + + data = torch.cat((main_token_ids, data), dim=0) + prompt = torch.cat((prompt, data), dim=1) + + return prompt + + +def load_model(checkpoint_path, device, precision, compile=False): + model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained( + checkpoint_path, load_weights=True + ) + + model = model.to(device=device, dtype=precision) + logger.info(f"Restored model from checkpoint") + + if isinstance(model, DualARTransformer): + decode_one_token = decode_one_token_ar + logger.info("Using DualARTransformer") + else: + decode_one_token = decode_one_token_naive + logger.info("Using NaiveTransformer") + + if compile: + logger.info("Compiling function...") + decode_one_token = torch.compile( + decode_one_token, + fullgraph=True, + backend="inductor" if 
torch.cuda.is_available() else "aot_eager", + mode="reduce-overhead" if torch.cuda.is_available() else None, + ) + + return model.eval(), decode_one_token + + +@dataclass +class GenerateResponse: + action: Literal["sample", "next"] + codes: Optional[torch.Tensor] = None + text: Optional[str] = None + + +def generate_long( + *, + model, + device: str | torch.device, + decode_one_token: callable, + text: str, + num_samples: int = 1, + max_new_tokens: int = 0, + top_p: int = 0.7, + repetition_penalty: float = 1.5, + temperature: float = 0.7, + compile: bool = False, + iterative_prompt: bool = True, + max_length: int = 2048, + chunk_length: int = 150, + prompt_text: Optional[str | list[str]] = None, + prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None, +): + assert 0 < top_p <= 1, "top_p must be in (0, 1]" + assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)" + assert 0 < temperature < 2, "temperature must be in (0, 2)" + + use_prompt = prompt_text is not None and prompt_tokens is not None + if use_prompt and isinstance(prompt_text, str): + prompt_text = [prompt_text] + prompt_tokens = [prompt_tokens] + + assert use_prompt is False or len(prompt_text) == len( + prompt_tokens + ), "Prompt text and tokens must have the same length" + + model_size = sum(p.numel() for p in model.parameters() if p.requires_grad) + tokenizer = model.tokenizer + im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + + encoded = [] + texts = split_text(text, chunk_length) if iterative_prompt else [text] + encoded_prompts = [] + + if use_prompt: + for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)): + encoded_prompts.append( + encode_tokens( + tokenizer, + string=t, + device=device, + prompt_tokens=c, + num_codebooks=model.config.num_codebooks, + ) + ) + + for idx, text in enumerate(texts): + encoded.append( + encode_tokens( + tokenizer, + string=text, + device=device, + num_codebooks=model.config.num_codebooks, + ) + ) + logger.info(f"Encoded text: {text}") + + # Move temperature, top_p, repetition_penalty to device + # This is important so that changing params doesn't trigger recompile + temperature = torch.tensor(temperature, device=device, dtype=torch.float) + top_p = torch.tensor(top_p, device=device, dtype=torch.float) + repetition_penalty = torch.tensor( + repetition_penalty, device=device, dtype=torch.float + ) + + for sample_idx in range(num_samples): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + global_encoded = [] + seg_idx = 0 + + while seg_idx < len(encoded): + logger.info( + f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}" + ) + + seg = encoded[seg_idx] + global_encoded.append(seg) + + lengths = reversed([seg.size(1) for seg in global_encoded]) + + # Pick last 2000 tokens + count = 0 + for i, length in enumerate(lengths): + count += length + if count + length > max_length - 1024 - sum( + t.shape[1] for t in encoded_prompts + ): + break + + if i != 0 and i % 2 == 0: + i -= 1 + + # Rotate the list, always make sure first segment is included to avoid drift + if i < len(global_encoded) - 2: + partial_encoded = global_encoded[:2] + global_encoded[-i:] + else: + partial_encoded = global_encoded + + if use_prompt: + partial_encoded = encoded_prompts + partial_encoded + + cat_encoded = torch.cat(partial_encoded, dim=1) + prompt_length = cat_encoded.size(1) + + t0 = time.perf_counter() + y = generate( + model=model, + prompt=cat_encoded, + max_new_tokens=max_new_tokens, + im_end_id=im_end_id, + 
decode_one_token=decode_one_token, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + + if sample_idx == 0 and seg_idx == 0 and compile: + logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + t = time.perf_counter() - t0 + + tokens_generated = y.size(1) - prompt_length + tokens_sec = tokens_generated / t + logger.info( + f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec" + ) + logger.info( + f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s" + ) + + if torch.cuda.is_available(): + logger.info( + f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB" + ) + + # Put the generated tokens + # since there is and tokens, we remove last 2 tokens + codes = y[1:, prompt_length:-1].clone() + codes = codes - 1 + assert (codes >= 0).all(), f"Negative code found" + + decoded = y[:, prompt_length:-1].clone() + # But for global encoding, we should keep the token + + global_encoded.append(decoded) + assert (codes >= 0).all(), f"Negative code found: {codes}" + yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx]) + seg_idx += 1 + + # This indicates the end of the current sample + yield GenerateResponse(action="next") + + +@dataclass +class WrappedGenerateResponse: + status: Literal["success", "error"] + response: Optional[GenerateResponse | Exception] = None + + +@dataclass +class GenerateRequest: + request: dict + response_queue: queue.Queue + + +def launch_thread_safe_queue( + checkpoint_path, + device, + precision, + compile: bool = False, +): + input_queue = queue.Queue() + init_event = threading.Event() + + def worker(): + model, decode_one_token = load_model( + checkpoint_path, device, precision, compile=compile + ) + with torch.device(device): + model.setup_caches( + max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype + ) + init_event.set() + + while True: + item: GenerateRequest | None = input_queue.get() + if item is None: + break + + kwargs = item.request + response_queue = item.response_queue + + try: + for chunk in generate_long( + model=model, decode_one_token=decode_one_token, **kwargs + ): + response_queue.put( + WrappedGenerateResponse(status="success", response=chunk) + ) + except Exception as e: + response_queue.put(WrappedGenerateResponse(status="error", response=e)) + + threading.Thread(target=worker, daemon=True).start() + init_event.wait() + + return input_queue + + +@click.command() +@click.option( + "--text", + type=str, + default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.", +) +@click.option("--prompt-text", type=str, default=None, multiple=True) +@click.option( + "--prompt-tokens", + type=click.Path(path_type=Path, exists=True), + default=None, + multiple=True, +) +@click.option("--num-samples", type=int, default=1) +@click.option("--max-new-tokens", type=int, default=1024) +@click.option("--top-p", type=float, default=0.7) +@click.option("--repetition-penalty", type=float, default=1.2) +@click.option("--temperature", type=float, default=0.7) +@click.option( + "--checkpoint-path", + type=click.Path(path_type=Path, exists=True), + default="checkpoints/fish-speech-1.4", +) +@click.option("--device", type=str, default="cuda") +@click.option("--compile/--no-compile", default=False) +@click.option("--seed", type=int, default=42) +@click.option("--half/--no-half", default=False) +@click.option("--iterative-prompt/--no-iterative-prompt", default=True) 
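The `launch_thread_safe_queue` helper above hides model loading and generation behind a worker thread and a pair of queues. As a quick orientation before the CLI entry point that follows, here is a minimal caller-side sketch (not part of the patch); the checkpoint path and request fields are illustrative assumptions, and any keyword accepted by `generate_long` can be placed in the request dict.

```python
import queue

import torch

from tools.llama.generate import GenerateRequest, launch_thread_safe_queue

device = "cuda" if torch.cuda.is_available() else "cpu"

# Spin up the worker; this blocks until the model has finished loading.
input_queue = launch_thread_safe_queue(
    checkpoint_path="checkpoints/fish-speech-1.4",  # assumed default checkpoint
    device=device,
    precision=torch.bfloat16,
    compile=False,
)

# Submit one request and drain its responses.
response_queue = queue.Queue()
input_queue.put(
    GenerateRequest(
        request={"device": device, "text": "Hello world.", "max_new_tokens": 512},
        response_queue=response_queue,
    )
)

while True:
    wrapped = response_queue.get()
    if wrapped.status == "error":
        raise wrapped.response
    if wrapped.response.action == "next":  # end of the current sample
        break
    print("Received codes with shape", tuple(wrapped.response.codes.shape))
```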
+@click.option("--chunk-length", type=int, default=100) +def main( + text: str, + prompt_text: Optional[list[str]], + prompt_tokens: Optional[list[Path]], + num_samples: int, + max_new_tokens: int, + top_p: int, + repetition_penalty: float, + temperature: float, + checkpoint_path: Path, + device: str, + compile: bool, + seed: int, + half: bool, + iterative_prompt: bool, + chunk_length: int, +) -> None: + + precision = torch.half if half else torch.bfloat16 + + if prompt_text is not None and len(prompt_text) != len(prompt_tokens): + raise ValueError( + f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same" + ) + + logger.info("Loading model ...") + t0 = time.time() + model, decode_one_token = load_model( + checkpoint_path, device, precision, compile=compile + ) + with torch.device(device): + model.setup_caches( + max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + logger.info(f"Time to load model: {time.time() - t0:.02f} seconds") + + if prompt_tokens is not None: + prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens] + + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + generator = generate_long( + model=model, + device=device, + decode_one_token=decode_one_token, + text=text, + num_samples=num_samples, + max_new_tokens=max_new_tokens, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + compile=compile, + iterative_prompt=iterative_prompt, + chunk_length=chunk_length, + prompt_text=prompt_text, + prompt_tokens=prompt_tokens, + ) + + idx = 0 + codes = [] + + for response in generator: + if response.action == "sample": + codes.append(response.codes) + logger.info(f"Sampled text: {response.text}") + elif response.action == "next": + if codes: + np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy()) + logger.info(f"Saved codes to codes_{idx}.npy") + logger.info(f"Next sample") + codes = [] + idx += 1 + else: + logger.error(f"Error: {response}") + + +if __name__ == "__main__": + main() diff --git a/tools/llama/merge_lora.py b/tools/llama/merge_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bd3cbd725c4eccbe78f711d9718dfb278a6aa7 --- /dev/null +++ b/tools/llama/merge_lora.py @@ -0,0 +1,95 @@ +import shutil +from copy import deepcopy +from pathlib import Path + +import click +import hydra +import torch +from hydra import compose, initialize +from hydra.utils import instantiate +from loguru import logger + +from fish_speech.models.text2semantic.llama import BaseTransformer +from fish_speech.models.text2semantic.lora import get_merged_state_dict + + +@click.command() +@click.option("--lora-config", type=str, default="r_8_alpha_16") +@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4") +@click.option("--lora-weight", type=str, required=True) +@click.option("--output", type=str, required=True) +def merge(lora_config, base_weight, lora_weight, output): + output = Path(output) + logger.info( + f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}" + ) + + with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"): + cfg = compose(config_name=lora_config) + + lora_config = instantiate(cfg) + logger.info(f"Loaded lora model with config {lora_config}") + + llama_model = BaseTransformer.from_pretrained( + path=base_weight, + load_weights=True, + 
lora_config=lora_config, + ) + logger.info(f"Loaded llama model") + + llama_state_dict = llama_model.state_dict() + llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k} + llama_state_dict_copy = deepcopy(llama_state_dict) + lora_state_dict = torch.load(lora_weight, map_location="cpu") + + if "state_dict" in llama_state_dict: + llama_state_dict = llama_state_dict["state_dict"] + + if "state_dict" in lora_state_dict: + lora_state_dict = lora_state_dict["state_dict"] + + # remove prefix model. + if any(k.startswith("model.") for k in llama_state_dict.keys()): + llama_state_dict = { + k.replace("model.", ""): v + for k, v in llama_state_dict.items() + if k.startswith("model.") + } + if any(k.startswith("model.") for k in lora_state_dict.keys()): + lora_state_dict = { + k.replace("model.", ""): v + for k, v in lora_state_dict.items() + if k.startswith("model.") + } + + logger.info(f"Found {len(llama_state_dict)} keys in llama model") + logger.info(f"Found {len(lora_state_dict)} keys in lora model") + + merged_state_dict = llama_state_dict | lora_state_dict + llama_model.load_state_dict(merged_state_dict, strict=True) + logger.info(f"Merged model loaded") + + # Trigger eval mode to merge lora + llama_model.eval() + llama_model.save_pretrained(output, drop_lora=True) + logger.info(f"Saved merged model to {output}, validating") + + new_state_dict = torch.load(output / "model.pth", map_location="cpu") + original_keys = set(llama_state_dict_copy.keys()) + merged_keys = set(new_state_dict.keys()) + + assert original_keys == merged_keys, "Keys should be same" + + for key in original_keys: + diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item() + if diff_l1 != 0: + break + else: + logger.error("Merged model is same as the original model") + exit(1) + + logger.info("Merged model is different from the original model, check passed") + + +if __name__ == "__main__": + merge() diff --git a/tools/llama/quantize.py b/tools/llama/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..e629d944b5d1e262f6c0517480980fcac01dad86 --- /dev/null +++ b/tools/llama/quantize.py @@ -0,0 +1,497 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +import datetime +import shutil + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
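For readers who have not worked with LoRA before, the merge performed above (finalized by `eval()` plus `save_pretrained(drop_lora=True)`) amounts to folding a low-rank update back into each dense weight. The sketch below is a generic illustration of that fold-in for a single linear layer; the shapes and the `alpha / r` scaling follow the common LoRA formulation (the `r_8_alpha_16` config name suggests r=8, alpha=16), and this is not the actual `get_merged_state_dict` implementation.

```python
import torch


def merge_lora_weight(
    w_base: torch.Tensor,  # (out_features, in_features)
    lora_a: torch.Tensor,  # (r, in_features)
    lora_b: torch.Tensor,  # (out_features, r)
    alpha: float = 16.0,
    r: int = 8,
) -> torch.Tensor:
    """Fold the low-rank update into the dense weight: W + (alpha / r) * B @ A."""
    return w_base + (alpha / r) * (lora_b @ lora_a)


# Tiny self-check: a zero-initialized adapter leaves the base weight untouched.
w = torch.randn(32, 64)
assert torch.equal(merge_lora_weight(w, torch.zeros(8, 64), torch.zeros(32, 8)), w)
```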
+import time +from pathlib import Path + +import click +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fish_speech.models.text2semantic.llama import find_multiple +from tools.llama.generate import load_model + +##### Quantization Primitives ###### + + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + + +def get_group_qparams(w, n_bit=4, groupsize=128): + # needed for GPTQ with padding + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( + torch.bfloat16 + ).reshape(w.shape[0], -1) + + +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + +def unpack_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + + +def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int32 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape_as(w) 
+ ) + + return w_int32 + + +def group_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_group_qparams(w, n_bit, groupsize) + w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + return w_int32, scales_and_zeros + + +def group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int32.shape[-1] + assert w_int32.shape[-1] % groupsize == 0 + assert w_int32.dim() == 2 + + w_int32_grouped = w_int32.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) + ) + return w_dq + + +def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): + scales, zeros = unpack_scales_and_zeros(scales_and_zeros) + return group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit, groupsize + ) + + +class QuantHandler: + def __init__(self, mod): + self.mod = mod + + def create_quantized_state_dict(self) -> "StateDict": + pass + + def convert_for_runtime(self) -> "nn.Module": + pass + + +##### Weight-only int8 per-channel quantized code ###### + + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr( + module, + name, + WeightOnlyInt8Linear(child.in_features, child.out_features), + ) + else: + replace_linear_weight_only_int8_per_channel(child) + + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel( + mod.weight.float(), -128, 127, torch.int8 + ) + cur_state_dict[f"{fqn}.weight"] = int8_weight + cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer( + "weight", torch.empty((out_features, in_features), dtype=torch.int8) + ) + self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + + +##### weight only int4 per channel groupwise quantized code ###### + + +def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + weight_int32, scales_and_zeros = group_quantize_tensor( + weight_bf16, n_bit=4, groupsize=groupsize + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + weight_int32, inner_k_tiles + ) + return weight_int4pack, scales_and_zeros + + +def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + 
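As a quick sanity check on the group-wise quantization helpers defined above, the sketch below (not part of the patch) quantizes a random bfloat16 weight and dequantizes it again. The shapes and group size are arbitrary choices, and the cast to float32 is needed because `unpack_scales_and_zeros` asserts a float32 input.

```python
import torch

from tools.llama.quantize import group_dequantize_tensor, group_quantize_tensor

w = torch.randn(16, 256, dtype=torch.bfloat16)

w_int32, scales_and_zeros = group_quantize_tensor(w, n_bit=4, groupsize=128)
w_dq = group_dequantize_tensor(
    w_int32, scales_and_zeros.float(), n_bit=4, groupsize=128
)

# Per-element error is bounded by roughly half a quantization step per group.
err = (w_dq - w.float()).abs()
print(f"max abs error: {err.max():.4f}, mean abs error: {err.mean():.4f}")
```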
origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm( + x, weight_int4pack, groupsize, scales_and_zeros + ) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + +def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): + return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=False, + ), + ) + elif padding: + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=True, + ), + ) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding) + + +class WeightOnlyInt4QuantHandler: + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + assert groupsize in [32, 64, 128, 256] + assert inner_k_tiles in [2, 4, 8] + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + weight = mod.weight.data + if not _check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): + if self.padding: + import torch.nn.functional as F + + print( + f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" + ) + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad( + weight, pad=(0, padded_in_features - in_features) + ) + else: + print( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it" + ) + continue + ( + weight_int4pack, + scales_and_zeros, + ) = prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16).to("cuda"), + self.groupsize, + self.inner_k_tiles, + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) + return self.mod + + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, + padding: bool = True, + ) -> None: + super().__init__() + self.padding = padding + if padding: + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features 
% 8 == 0" + assert ( + in_features % (inner_k_tiles * 16) == 0 + ), "require in_features % (innerKTiles * 16) == 0" + self.register_buffer( + "weight", + torch.empty( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales_and_zeros", + torch.empty( + (in_features // groupsize, out_features, 2), dtype=torch.bfloat16 + ), + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) + + +def generate_folder_name(): + now = datetime.datetime.now() + folder_name = now.strftime("%Y%m%d_%H%M%S") + return folder_name + + +@click.command() +@click.option( + "--checkpoint-path", + type=click.Path(path_type=Path, exists=True), + default="checkpoints/fish-speech-1.4", +) +@click.option( + "--mode", type=str, default="int8", help="type of quantization to perform" +) +@click.option( + "--groupsize", type=int, default=128, help="Group size for int4 quantization." +) +@click.option("--timestamp", type=str, default="None", help="When to do quantization") +def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None: + + device = "cpu" + precision = torch.bfloat16 + + print("Loading model ...") + t0 = time.time() + + model, _ = load_model( + checkpoint_path=checkpoint_path, + device=device, + precision=precision, + compile=False, + ) + vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + now = timestamp if timestamp != "None" else generate_folder_name() + + if mode == "int8": + print( + "Quantizing model weights for int8 weight-only symmetric per-channel quantization" + ) + quant_handler = WeightOnlyInt8QuantHandler(model) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path + dst_name = Path(f"checkpoints/fs-1.2-int8-{now}") + shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve())) + if (dst_name / vq_model).exists(): + (dst_name / vq_model).unlink() + quantize_path = dst_name / "model.pth" + + elif mode == "int4": + print( + "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization" + ) + quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path + dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}") + shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve())) + if (dst_name / vq_model).exists(): + (dst_name / vq_model).unlink() + quantize_path = dst_name / "model.pth" + + else: + raise ValueError( + f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]" + ) + + print(f"Writing quantized weights to {quantize_path}") + quantize_path.unlink(missing_ok=True) # remove existing file if one already there + torch.save(quantized_state_dict, quantize_path) + print(f"Quantization complete took {time.time() - t0:.02f} seconds") + + +if __name__ == "__main__": + quantize() diff --git a/tools/llama/rebuild_tokenizer.py b/tools/llama/rebuild_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ea64fa6788833000c8dc41e3d570dd5b250fb14b --- /dev/null +++ b/tools/llama/rebuild_tokenizer.py @@ -0,0 +1,57 @@ +from tokenizers import Tokenizer, decoders, 
models, pre_tokenizers, processors, trainers +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +# Initialize a tokenizer +tokenizer = Tokenizer(models.BPE()) + +# Customize pre-tokenization and decoding +tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) +tokenizer.decoder = decoders.ByteLevel() +tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) + +# Don't train the tokenizer +trainer = trainers.BpeTrainer( + vocab_size=0, + min_frequency=2, + initial_alphabet=pre_tokenizers.ByteLevel.alphabet(), + special_tokens=[ + "<|begin_of_sequence|>", + "<|end_of_sequence|>", + "<|im_start|>", + "<|im_sep|>", # system, user, assistant, etc. + "<|im_end|>", + "<|semantic|>", # audio features + "<|pad|>", + ], +) + +# <|im_start|>user<|im_sep|>...<|im_end|> +# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|> +tokenizer.train_from_iterator([], trainer=trainer) + +print(len(tokenizer.get_vocab())) +x = tokenizer.encode( + "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>" +).ids +print(x, len(x)) +print(tokenizer.decode(x, skip_special_tokens=True)) + + +tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + pad_token="<|pad|>", + bos_token="<|begin_of_sequence|>", + eos_token="<|end_of_sequence|>", +) + +# Try tokenizing a new sequence +sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>" +encoded = tokenizer(sequence).input_ids + +print("Test encoding....") +print(f"\tSentence: {sequence}") +print(f"\tEncoded: {encoded}") +print(f"\tDecoded: {tokenizer.batch_decode(encoded)}") +print(f"\tDecoded: {tokenizer.decode(encoded)}") + +tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True) diff --git a/tools/msgpack_api.py b/tools/msgpack_api.py new file mode 100644 index 0000000000000000000000000000000000000000..67f907bf55283f96f07d89b734403209290421c9 --- /dev/null +++ b/tools/msgpack_api.py @@ -0,0 +1,34 @@ +import httpx +import ormsgpack + +from tools.commons import ServeReferenceAudio, ServeTTSRequest + +# priority: ref_id > references +request = ServeTTSRequest( + text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.", + # reference_id="114514", + references=[ + ServeReferenceAudio( + audio=open("lengyue.wav", "rb").read(), + text=open("lengyue.lab", "r", encoding="utf-8").read(), + ) + ], + streaming=True, +) + +with ( + httpx.Client() as client, + open("hello.wav", "wb") as f, +): + with client.stream( + "POST", + "http://127.0.0.1:8080/v1/tts", + content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), + headers={ + "authorization": "Bearer YOUR_API_KEY", + "content-type": "application/msgpack", + }, + timeout=None, + ) as response: + for chunk in response.iter_bytes(): + f.write(chunk) diff --git a/tools/post_api.py b/tools/post_api.py new file mode 100644 index 0000000000000000000000000000000000000000..c20dc455c3ec5a6c69b879537c57cddb13495ce1 --- /dev/null +++ b/tools/post_api.py @@ -0,0 +1,205 @@ +import argparse +import base64 +import wave + +import ormsgpack +import pyaudio +import requests +from pydub import AudioSegment +from pydub.playback import play + +from tools.commons import ServeReferenceAudio, ServeTTSRequest +from tools.file import audio_to_bytes, read_ref_text + + +def parse_args(): + + parser = argparse.ArgumentParser( + description="Send a WAV file and text 
to a server and receive synthesized audio." + ) + + parser.add_argument( + "--url", + "-u", + type=str, + default="http://127.0.0.1:8080/v1/tts", + help="URL of the server", + ) + parser.add_argument( + "--text", "-t", type=str, required=True, help="Text to be synthesized" + ) + parser.add_argument( + "--reference_id", + "-id", + type=str, + default=None, + help="ID of the reference model o be used for the speech", + ) + parser.add_argument( + "--reference_audio", + "-ra", + type=str, + nargs="+", + default=None, + help="Path to the WAV file", + ) + parser.add_argument( + "--reference_text", + "-rt", + type=str, + nargs="+", + default=None, + help="Reference text for voice synthesis", + ) + parser.add_argument( + "--output", + "-o", + type=str, + default="generated_audio", + help="Output audio file name", + ) + parser.add_argument( + "--play", + type=bool, + default=True, + help="Whether to play audio after receiving data", + ) + parser.add_argument("--normalize", type=bool, default=True) + parser.add_argument( + "--format", type=str, choices=["wav", "mp3", "flac"], default="wav" + ) + parser.add_argument("--mp3_bitrate", type=int, default=64) + parser.add_argument("--opus_bitrate", type=int, default=-1000) + parser.add_argument("--latency", type=str, default="normal", help="延迟选项") + parser.add_argument( + "--max_new_tokens", + type=int, + default=1024, + help="Maximum new tokens to generate", + ) + parser.add_argument( + "--chunk_length", type=int, default=100, help="Chunk length for synthesis" + ) + parser.add_argument( + "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis" + ) + parser.add_argument( + "--repetition_penalty", + type=float, + default=1.2, + help="Repetition penalty for synthesis", + ) + parser.add_argument( + "--temperature", type=float, default=0.7, help="Temperature for sampling" + ) + parser.add_argument( + "--speaker", type=str, default=None, help="Speaker ID for voice synthesis" + ) + parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion") + parser.add_argument( + "--streaming", type=bool, default=False, help="Enable streaming response" + ) + parser.add_argument( + "--channels", type=int, default=1, help="Number of audio channels" + ) + parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio") + + return parser.parse_args() + + +if __name__ == "__main__": + + args = parse_args() + + idstr: str | None = args.reference_id + # priority: ref_id > [{text, audio},...] 
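Below is an assumption-laden sketch of a simpler headless client: it streams the response straight into a WAV file with the standard `wave` module instead of playing it through PyAudio. It reuses the request fields shown in this file and assumes the server streams 16-bit mono PCM at 44100 Hz (matching this script's `--channels` and `--rate` defaults); adjust these if your server is configured differently.

```python
import wave

import ormsgpack
import requests

from tools.commons import ServeTTSRequest

request = ServeTTSRequest(text="Hello world.", references=[], streaming=True)

with requests.post(
    "http://127.0.0.1:8080/v1/tts",
    data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
    stream=True,
    headers={
        "authorization": "Bearer YOUR_API_KEY",
        "content-type": "application/msgpack",
    },
) as response, wave.open("generated_audio.wav", "wb") as wf:
    response.raise_for_status()
    wf.setnchannels(1)      # assumed mono
    wf.setsampwidth(2)      # assumed 16-bit PCM
    wf.setframerate(44100)  # assumed sample rate
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:
            wf.writeframesraw(chunk)
```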
+ if idstr is None: + ref_audios = args.reference_audio + ref_texts = args.reference_text + if ref_audios is None: + byte_audios = [] + else: + byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios] + if ref_texts is None: + ref_texts = [] + else: + ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts] + else: + byte_audios = [] + ref_texts = [] + pass # in api.py + + data = { + "text": args.text, + "references": [ + ServeReferenceAudio(audio=ref_audio, text=ref_text) + for ref_text, ref_audio in zip(ref_texts, byte_audios) + ], + "reference_id": idstr, + "normalize": args.normalize, + "format": args.format, + "mp3_bitrate": args.mp3_bitrate, + "opus_bitrate": args.opus_bitrate, + "max_new_tokens": args.max_new_tokens, + "chunk_length": args.chunk_length, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "temperature": args.temperature, + "speaker": args.speaker, + "emotion": args.emotion, + "streaming": args.streaming, + } + + pydantic_data = ServeTTSRequest(**data) + + response = requests.post( + args.url, + data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC), + stream=args.streaming, + headers={ + "authorization": "Bearer YOUR_API_KEY", + "content-type": "application/msgpack", + }, + ) + + if response.status_code == 200: + if args.streaming: + p = pyaudio.PyAudio() + audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format + stream = p.open( + format=audio_format, channels=args.channels, rate=args.rate, output=True + ) + + wf = wave.open(f"{args.output}.wav", "wb") + wf.setnchannels(args.channels) + wf.setsampwidth(p.get_sample_size(audio_format)) + wf.setframerate(args.rate) + + stream_stopped_flag = False + + try: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + stream.write(chunk) + wf.writeframesraw(chunk) + else: + if not stream_stopped_flag: + stream.stop_stream() + stream_stopped_flag = True + finally: + stream.close() + p.terminate() + wf.close() + else: + audio_content = response.content + audio_path = f"{args.output}.{args.format}" + with open(audio_path, "wb") as audio_file: + audio_file.write(audio_content) + + audio = AudioSegment.from_file(audio_path, format=args.format) + if args.play: + play(audio) + print(f"Audio has been saved to '{audio_path}'.") + else: + print(f"Request failed with status code {response.status_code}") + print(response.json()) diff --git a/tools/sensevoice/README.md b/tools/sensevoice/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9a2078aa2d96dfafb445384316f2041d9e819e63 --- /dev/null +++ b/tools/sensevoice/README.md @@ -0,0 +1,59 @@ +# FunASR Command Line Interface + +This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files. + +## Requirements + +- Python >= 3.10 +- PyTorch <= 2.3.1 +- ffmpeg, pydub, audio-separator[gpu]. + +## Installation + +Install the required packages: + +```bash +pip install -e .[stable] +``` + +Make sure you have `ffmpeg` installed and available in your `PATH`. + +## Usage + +### Basic Usage + +To run the tool with default settings: + +```bash +python tools/sensevoice/fun_asr.py --audio-dir --save-dir +``` + +## Options + +| Option | Description | +| :-----------------------: | :---------------------------------------------------------------------------: | +| --audio-dir | Directory containing audio or video files. 
| +| --save-dir | Directory to save processed audio files. | +| --device | Device to use for processing. Options: cuda (default) or cpu. | +| --language | Language of the transcription. Default is auto. | +| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. | +| --punc | Enable punctuation prediction. | +| --denoise | Enable noise reduction (vocal separation). | + +## Example + +To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled: + +```bash +python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise +``` + +## Additional Notes + +- The tool supports `both audio and video files`. Videos will be converted to audio automatically. +- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks. +- The script will automatically create necessary directories in the `--save-dir`. + +## Troubleshooting + +If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency. diff --git a/tools/sensevoice/__init__.py b/tools/sensevoice/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/sensevoice/auto_model.py b/tools/sensevoice/auto_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2e186617fe889500d01d95eccdafc5c0248b84 --- /dev/null +++ b/tools/sensevoice/auto_model.py @@ -0,0 +1,573 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
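The segmentation described in the README is driven by a voice-activity detector; one of the two VAD back-ends the script supports is the `silero_vad` package that `fun_asr.py` imports. A minimal stand-alone sketch of that VAD step (the audio path is a placeholder) looks like this:

```python
from silero_vad import get_speech_timestamps, load_silero_vad, read_audio

model = load_silero_vad()
wav = read_audio("path/to/audio.wav")  # helper resamples to 16 kHz
speech_timestamps = get_speech_timestamps(wav, model, return_seconds=True)
print(speech_timestamps)  # e.g. [{'start': 0.6, 'end': 4.2}, ...]
```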
+# MIT License (https://opensource.org/licenses/MIT) + +import copy +import json +import logging +import os.path +import random +import re +import string +import time + +import numpy as np +import torch +from funasr.download.download_model_from_hub import download_model +from funasr.download.file import download_from_url +from funasr.register import tables +from funasr.train_utils.load_pretrained_model import load_pretrained_model +from funasr.train_utils.set_all_random_seed import set_all_random_seed +from funasr.utils import export_utils, misc +from funasr.utils.load_utils import load_audio_text_image_video, load_bytes +from funasr.utils.misc import deep_update +from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en +from tqdm import tqdm + +from .vad_utils import merge_vad, slice_padding_audio_samples + +try: + from funasr.models.campplus.cluster_backend import ClusterBackend + from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk +except: + pass + + +def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): + """ """ + data_list = [] + key_list = [] + filelist = [".scp", ".txt", ".json", ".jsonl", ".text"] + + chars = string.ascii_letters + string.digits + if isinstance(data_in, str): + if data_in.startswith("http://") or data_in.startswith("https://"): # url + data_in = download_from_url(data_in) + + if isinstance(data_in, str) and os.path.exists( + data_in + ): # wav_path; filelist: wav.scp, file.jsonl;text.txt; + _, file_extension = os.path.splitext(data_in) + file_extension = file_extension.lower() + if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt; + with open(data_in, encoding="utf-8") as fin: + for line in fin: + key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + if data_in.endswith( + ".jsonl" + ): # file.jsonl: json.dumps({"source": data}) + lines = json.loads(line.strip()) + data = lines["source"] + key = data["key"] if "key" in data else key + else: # filelist, wav.scp, text.txt: id \t data or data + lines = line.strip().split(maxsplit=1) + data = lines[1] if len(lines) > 1 else lines[0] + key = lines[0] if len(lines) > 1 else key + + data_list.append(data) + key_list.append(key) + else: + if key is None: + # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + key = misc.extract_filename_without_extension(data_in) + data_list = [data_in] + key_list = [key] + elif isinstance(data_in, (list, tuple)): + if data_type is not None and isinstance( + data_type, (list, tuple) + ): # mutiple inputs + data_list_tmp = [] + for data_in_i, data_type_i in zip(data_in, data_type): + key_list, data_list_i = prepare_data_iterator( + data_in=data_in_i, data_type=data_type_i + ) + data_list_tmp.append(data_list_i) + data_list = [] + for item in zip(*data_list_tmp): + data_list.append(item) + else: + # [audio sample point, fbank, text] + data_list = data_in + key_list = [] + for data_i in data_in: + if isinstance(data_i, str) and os.path.exists(data_i): + key = misc.extract_filename_without_extension(data_i) + else: + if key is None: + key = "rand_key_" + "".join( + random.choice(chars) for _ in range(13) + ) + key_list.append(key) + + else: # raw text; audio sample point, fbank; bytes + if isinstance(data_in, bytes): # audio bytes + data_in = load_bytes(data_in) + if key is None: + key = "rand_key_" + "".join(random.choice(chars) for _ in range(13)) + data_list = [data_in] + key_list = [key] + + return key_list, data_list + + +class AutoModel: + + def 
__init__(self, **kwargs): + + try: + from funasr.utils.version_checker import check_for_update + + print( + "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel" + ) + check_for_update(disable=kwargs.get("disable_update", False)) + except: + pass + + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + logging.basicConfig(level=log_level) + + model, kwargs = self.build_model(**kwargs) + + # if vad_model is not None, build vad model else None + vad_model = kwargs.get("vad_model", None) + vad_kwargs = ( + {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {}) + ) + if vad_model is not None: + logging.info("Building VAD model.") + vad_kwargs["model"] = vad_model + vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master") + vad_kwargs["device"] = kwargs["device"] + vad_model, vad_kwargs = self.build_model(**vad_kwargs) + + # if punc_model is not None, build punc model else None + punc_model = kwargs.get("punc_model", None) + punc_kwargs = ( + {} + if kwargs.get("punc_kwargs", {}) is None + else kwargs.get("punc_kwargs", {}) + ) + if punc_model is not None: + logging.info("Building punc model.") + punc_kwargs["model"] = punc_model + punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master") + punc_kwargs["device"] = kwargs["device"] + punc_model, punc_kwargs = self.build_model(**punc_kwargs) + + # if spk_model is not None, build spk model else None + spk_model = kwargs.get("spk_model", None) + spk_kwargs = ( + {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {}) + ) + if spk_model is not None: + logging.info("Building SPK model.") + spk_kwargs["model"] = spk_model + spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master") + spk_kwargs["device"] = kwargs["device"] + spk_model, spk_kwargs = self.build_model(**spk_kwargs) + self.cb_model = ClusterBackend().to(kwargs["device"]) + spk_mode = kwargs.get("spk_mode", "punc_segment") + if spk_mode not in ["default", "vad_segment", "punc_segment"]: + logging.error( + "spk_mode should be one of default, vad_segment and punc_segment." 
+ ) + self.spk_mode = spk_mode + + self.kwargs = kwargs + self.model = model + self.vad_model = vad_model + self.vad_kwargs = vad_kwargs + self.punc_model = punc_model + self.punc_kwargs = punc_kwargs + self.spk_model = spk_model + self.spk_kwargs = spk_kwargs + self.model_path = kwargs.get("model_path") + + @staticmethod + def build_model(**kwargs): + assert "model" in kwargs + if "model_conf" not in kwargs: + logging.info( + "download models from model hub: {}".format(kwargs.get("hub", "ms")) + ) + kwargs = download_model(**kwargs) + + set_all_random_seed(kwargs.get("seed", 0)) + + device = kwargs.get("device", "cuda") + if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0: + device = "cpu" + kwargs["batch_size"] = 1 + kwargs["device"] = device + + torch.set_num_threads(kwargs.get("ncpu", 4)) + + # build tokenizer + tokenizer = kwargs.get("tokenizer", None) + if tokenizer is not None: + tokenizer_class = tables.tokenizer_classes.get(tokenizer) + tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {})) + kwargs["token_list"] = ( + tokenizer.token_list if hasattr(tokenizer, "token_list") else None + ) + kwargs["token_list"] = ( + tokenizer.get_vocab() + if hasattr(tokenizer, "get_vocab") + else kwargs["token_list"] + ) + vocab_size = ( + len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 + ) + if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"): + vocab_size = tokenizer.get_vocab_size() + else: + vocab_size = -1 + kwargs["tokenizer"] = tokenizer + + # build frontend + frontend = kwargs.get("frontend", None) + kwargs["input_size"] = None + if frontend is not None: + frontend_class = tables.frontend_classes.get(frontend) + frontend = frontend_class(**kwargs.get("frontend_conf", {})) + kwargs["input_size"] = ( + frontend.output_size() if hasattr(frontend, "output_size") else None + ) + kwargs["frontend"] = frontend + # build model + model_class = tables.model_classes.get(kwargs["model"]) + assert model_class is not None, f'{kwargs["model"]} is not registered' + model_conf = {} + deep_update(model_conf, kwargs.get("model_conf", {})) + deep_update(model_conf, kwargs) + model = model_class(**model_conf, vocab_size=vocab_size) + + # init_param + init_param = kwargs.get("init_param", None) + if init_param is not None: + if os.path.exists(init_param): + logging.info(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=model, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + else: + print(f"error, init_param does not exist!: {init_param}") + + # fp16 + if kwargs.get("fp16", False): + model.to(torch.float16) + elif kwargs.get("bf16", False): + model.to(torch.bfloat16) + model.to(device) + + if not kwargs.get("disable_log", True): + tables.print() + + return model, kwargs + + def __call__(self, *args, **cfg): + kwargs = self.kwargs + deep_update(kwargs, cfg) + res = self.model(*args, kwargs) + return res + + def generate(self, input, input_len=None, **cfg): + if self.vad_model is None: + return self.inference(input, input_len=input_len, **cfg) + + else: + return self.inference_with_vad(input, input_len=input_len, **cfg) + + def inference( + self, input, input_len=None, model=None, kwargs=None, key=None, **cfg + ): + kwargs = self.kwargs if kwargs is None else kwargs + if "cache" in kwargs: + kwargs.pop("cache") + deep_update(kwargs, cfg) + model = self.model if 
model is None else model + model.eval() + + batch_size = kwargs.get("batch_size", 1) + # if kwargs.get("device", "cpu") == "cpu": + # batch_size = 1 + + key_list, data_list = prepare_data_iterator( + input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key + ) + + speed_stats = {} + asr_result_list = [] + num_samples = len(data_list) + disable_pbar = self.kwargs.get("disable_pbar", False) + pbar = ( + tqdm(colour="blue", total=num_samples, dynamic_ncols=True) + if not disable_pbar + else None + ) + time_speech_total = 0.0 + time_escape_total = 0.0 + for beg_idx in range(0, num_samples, batch_size): + end_idx = min(num_samples, beg_idx + batch_size) + data_batch = data_list[beg_idx:end_idx] + key_batch = key_list[beg_idx:end_idx] + batch = {"data_in": data_batch, "key": key_batch} + + if (end_idx - beg_idx) == 1 and kwargs.get( + "data_type", None + ) == "fbank": # fbank + batch["data_in"] = data_batch[0] + batch["data_lengths"] = input_len + + time1 = time.perf_counter() + with torch.no_grad(): + res = model.inference(**batch, **kwargs) + if isinstance(res, (list, tuple)): + results = res[0] if len(res) > 0 else [{"text": ""}] + meta_data = res[1] if len(res) > 1 else {} + time2 = time.perf_counter() + + asr_result_list.extend(results) + + # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item() + batch_data_time = meta_data.get("batch_data_time", -1) + time_escape = time2 - time1 + speed_stats["load_data"] = meta_data.get("load_data", 0.0) + speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0) + speed_stats["forward"] = f"{time_escape:0.3f}" + speed_stats["batch_size"] = f"{len(results)}" + speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}" + description = f"{speed_stats}, " + if pbar: + pbar.update(end_idx - beg_idx) + pbar.set_description(description) + time_speech_total += batch_data_time + time_escape_total += time_escape + + if pbar: + # pbar.update(1) + pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") + torch.cuda.empty_cache() + return asr_result_list + + def vad(self, input, input_len=None, **cfg): + kwargs = self.kwargs + # step.1: compute the vad model + deep_update(self.vad_kwargs, cfg) + beg_vad = time.time() + res = self.inference( + input, + input_len=input_len, + model=self.vad_model, + kwargs=self.vad_kwargs, + **cfg, + ) + end_vad = time.time() + # FIX(gcf): concat the vad clips for sense vocie model for better aed + if cfg.get("merge_vad", False): + for i in range(len(res)): + res[i]["value"] = merge_vad( + res[i]["value"], kwargs.get("merge_length_s", 15) * 1000 + ) + elapsed = end_vad - beg_vad + return elapsed, res + + def inference_with_vadres(self, input, vad_res, input_len=None, **cfg): + + kwargs = self.kwargs + + # step.2 compute asr model + model = self.model + deep_update(kwargs, cfg) + batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1) + batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000 + kwargs["batch_size"] = batch_size + + key_list, data_list = prepare_data_iterator( + input, input_len=input_len, data_type=kwargs.get("data_type", None) + ) + results_ret_list = [] + time_speech_total_all_samples = 1e-6 + + beg_total = time.time() + pbar_total = ( + tqdm(colour="red", total=len(vad_res), dynamic_ncols=True) + if not kwargs.get("disable_pbar", False) + else None + ) + + for i in range(len(vad_res)): + key = vad_res[i]["key"] + vadsegments = vad_res[i]["value"] + input_i = data_list[i] + fs = kwargs["frontend"].fs if 
hasattr(kwargs["frontend"], "fs") else 16000 + speech = load_audio_text_image_video( + input_i, fs=fs, audio_fs=kwargs.get("fs", 16000) + ) + speech_lengths = len(speech) + n = len(vadsegments) + data_with_index = [(vadsegments[i], i) for i in range(n)] + sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0]) + results_sorted = [] + + if not len(sorted_data): + results_ret_list.append({"key": key, "text": "", "timestamp": []}) + logging.info("decoding, utt: {}, empty speech".format(key)) + continue + + if len(sorted_data) > 0 and len(sorted_data[0]) > 0: + batch_size = max( + batch_size, sorted_data[0][0][1] - sorted_data[0][0][0] + ) + + if kwargs["device"] == "cpu": + batch_size = 0 + + beg_idx = 0 + beg_asr_total = time.time() + time_speech_total_per_sample = speech_lengths / 16000 + time_speech_total_all_samples += time_speech_total_per_sample + + # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True) + + all_segments = [] + max_len_in_batch = 0 + end_idx = 1 + + for j, _ in enumerate(range(0, n)): + # pbar_sample.update(1) + sample_length = sorted_data[j][0][1] - sorted_data[j][0][0] + potential_batch_length = max(max_len_in_batch, sample_length) * ( + j + 1 - beg_idx + ) + # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0] + if ( + j < n - 1 + and sample_length < batch_size_threshold_ms + and potential_batch_length < batch_size + ): + max_len_in_batch = max(max_len_in_batch, sample_length) + end_idx += 1 + continue + + speech_j, speech_lengths_j, intervals = slice_padding_audio_samples( + speech, speech_lengths, sorted_data[beg_idx:end_idx] + ) + results = self.inference( + speech_j, input_len=None, model=model, kwargs=kwargs, **cfg + ) + + for _b in range(len(speech_j)): + results[_b]["interval"] = intervals[_b] + + if self.spk_model is not None: + # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]] + for _b in range(len(speech_j)): + vad_segments = [ + [ + sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0, + sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0, + np.array(speech_j[_b]), + ] + ] + segments = sv_chunk(vad_segments) + all_segments.extend(segments) + speech_b = [i[2] for i in segments] + spk_res = self.inference( + speech_b, + input_len=None, + model=self.spk_model, + kwargs=kwargs, + **cfg, + ) + results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"] + + beg_idx = end_idx + end_idx += 1 + max_len_in_batch = sample_length + if len(results) < 1: + continue + results_sorted.extend(results) + + # end_asr_total = time.time() + # time_escape_total_per_sample = end_asr_total - beg_asr_total + # pbar_sample.update(1) + # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " + # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, " + # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}") + + restored_data = [0] * n + for j in range(n): + index = sorted_data[j][1] + cur = results_sorted[j] + pattern = r"<\|([^|]+)\|>" + emotion_string = re.findall(pattern, cur["text"]) + cur["text"] = re.sub(pattern, "", cur["text"]) + cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string]) + if self.punc_model is not None and len(cur["text"].strip()) > 0: + deep_update(self.punc_kwargs, cfg) + punc_res = self.inference( + cur["text"], + model=self.punc_model, + kwargs=self.punc_kwargs, + **cfg, + ) + cur["text"] = punc_res[0]["text"] + + restored_data[index] = cur + + end_asr_total = time.time() + time_escape_total_per_sample = 
end_asr_total - beg_asr_total + if pbar_total: + pbar_total.update(1) + pbar_total.set_description( + f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " + f"time_speech: {time_speech_total_per_sample: 0.3f}, " + f"time_escape: {time_escape_total_per_sample:0.3f}" + ) + + # end_total = time.time() + # time_escape_total_all_samples = end_total - beg_total + # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, " + # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, " + # f"time_escape_all: {time_escape_total_all_samples:0.3f}") + return restored_data + + def export(self, input=None, **cfg): + """ + + :param input: + :param type: + :param quantize: + :param fallback_num: + :param calib_num: + :param opset_version: + :param cfg: + :return: + """ + + device = cfg.get("device", "cpu") + model = self.model.to(device=device) + kwargs = self.kwargs + deep_update(kwargs, cfg) + kwargs["device"] = device + del kwargs["model"] + model.eval() + + type = kwargs.get("type", "onnx") + + key_list, data_list = prepare_data_iterator( + input, input_len=None, data_type=kwargs.get("data_type", None), key=None + ) + + with torch.no_grad(): + export_dir = export_utils.export(model=model, data_in=data_list, **kwargs) + + return export_dir diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..6789316d5186db69c021758094649553c3638f66 --- /dev/null +++ b/tools/sensevoice/fun_asr.py @@ -0,0 +1,332 @@ +import gc +import os +import re + +from audio_separator.separator import Separator + +os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr" +os.environ["UVR5_CACHE"] = "./.cache/uvr5-models" +import json +import subprocess +from pathlib import Path + +import click +import torch +from loguru import logger +from pydub import AudioSegment +from silero_vad import get_speech_timestamps, load_silero_vad, read_audio +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files +from tools.sensevoice.auto_model import AutoModel + + +def uvr5_cli( + audio_dir: Path, + output_folder: Path, + audio_files: list[Path] | None = None, + output_format: str = "flac", + model: str = "BS-Roformer-Viperx-1297.ckpt", +): + # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"] + sepr = Separator( + model_file_dir=os.environ["UVR5_CACHE"], + output_dir=output_folder, + output_format=output_format, + ) + dictmodel = { + "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt", + "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt", + "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt", + "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt", + } + roformer_model = dictmodel[model] + sepr.load_model(roformer_model) + if audio_files is None: + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + total_files = len(audio_files) + + print(f"{total_files} audio files found") + + res = [] + for audio in tqdm(audio_files, desc="Denoising: "): + file_path = str(audio_dir / audio) + sep_out = sepr.separate(file_path) + if isinstance(sep_out, str): + res.append(sep_out) + elif isinstance(sep_out, list): + res.extend(sep_out) + del sepr + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return res, roformer_model + + +def 
get_sample_rate(media_path: Path): + result = subprocess.run( + [ + "ffprobe", + "-v", + "quiet", + "-print_format", + "json", + "-show_streams", + str(media_path), + ], + capture_output=True, + text=True, + check=True, + ) + media_info = json.loads(result.stdout) + for stream in media_info.get("streams", []): + if stream.get("codec_type") == "audio": + return stream.get("sample_rate") + return "44100" # Default sample rate if not found + + +def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"): + sr = get_sample_rate(src_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + if src_path.resolve() == out_path.resolve(): + output = str(out_path.with_stem(out_path.stem + f"_{sr}")) + else: + output = str(out_path) + subprocess.run( + [ + "ffmpeg", + "-loglevel", + "error", + "-i", + str(src_path), + "-acodec", + "pcm_s16le" if out_fmt == "wav" else "flac", + "-ar", + sr, + "-ac", + "1", + "-y", + output, + ], + check=True, + ) + return out_path + + +def convert_video_to_audio(video_path: Path, audio_dir: Path): + cur_dir = audio_dir / video_path.relative_to(audio_dir).parent + vocals = [ + p + for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*") + if p.suffix in AUDIO_EXTENSIONS + ] + if len(vocals) > 0: + return vocals[0] + audio_path = cur_dir / f"{video_path.stem}.wav" + convert_to_mono(video_path, audio_path) + return audio_path + + +@click.command() +@click.option("--audio-dir", required=True, help="Directory containing audio files") +@click.option( + "--save-dir", required=True, help="Directory to save processed audio files" +) +@click.option("--device", default="cuda", help="Device to use [cuda / cpu]") +@click.option("--language", default="auto", help="Language of the transcription") +@click.option( + "--max_single_segment_time", + default=20000, + type=int, + help="Maximum of Output single audio duration(ms)", +) +@click.option("--fsmn-vad/--silero-vad", default=False) +@click.option("--punc/--no-punc", default=False) +@click.option("--denoise/--no-denoise", default=False) +@click.option("--save_emo/--no_save_emo", default=False) +def main( + audio_dir: str, + save_dir: str, + device: str, + language: str, + max_single_segment_time: int, + fsmn_vad: bool, + punc: bool, + denoise: bool, + save_emo: bool, +): + + audios_path = Path(audio_dir) + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + video_files = list_files( + path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True + ) + v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files] + + if denoise: + VOCAL = "_(Vocals)" + original_files = [ + p + for p in audios_path.glob("**/*") + if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem + ] + + _, cur_model = uvr5_cli( + audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files + ) + need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")] + need_remove.extend(original_files) + for _ in need_remove: + _.unlink() + vocal_files = [ + p + for p in audios_path.glob("**/*") + if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem + ] + for f in vocal_files: + fn, ext = f.stem, f.suffix + + v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0]) + if v_pos != -1: + new_fn = fn[: v_pos + len(VOCAL)] + new_f = f.with_name(new_fn + ext) + f = f.rename(new_f) + convert_to_mono(f, f, "flac") + f.unlink() + + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + + logger.info("Loading / Downloading Funasr model...") + + model_dir = "iic/SenseVoiceSmall" + + 
vad_model = "fsmn-vad" if fsmn_vad else None + vad_kwargs = {"max_single_segment_time": max_single_segment_time} + punc_model = "ct-punc" if punc else None + + manager = AutoModel( + model=model_dir, + trust_remote_code=False, + vad_model=vad_model, + vad_kwargs=vad_kwargs, + punc_model=punc_model, + device=device, + ) + + if not fsmn_vad and vad_model is None: + vad_model = load_silero_vad() + + logger.info("Model loaded.") + + pattern = re.compile(r"_\d{3}\.") + + for file_path in tqdm(audio_files, desc="Processing audio file"): + + if pattern.search(file_path.name): + # logger.info(f"Skipping {file_path} as it has already been processed.") + continue + + file_stem = file_path.stem + file_suffix = file_path.suffix + + rel_path = Path(file_path).relative_to(audio_dir) + (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) + + audio = AudioSegment.from_file(file_path) + + cfg = dict( + cache={}, + language=language, # "zh", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + batch_size_s=60, + ) + + if fsmn_vad: + elapsed, vad_res = manager.vad(input=str(file_path), **cfg) + else: + wav = read_audio( + str(file_path) + ) # backend (sox, soundfile, or ffmpeg) required! + audio_key = file_path.stem + audio_val = [] + speech_timestamps = get_speech_timestamps( + wav, + vad_model, + max_speech_duration_s=max_single_segment_time // 1000, + return_seconds=True, + ) + + audio_val = [ + [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)] + for timestamp in speech_timestamps + ] + vad_res = [] + vad_res.append(dict(key=audio_key, value=audio_val)) + + res = manager.inference_with_vadres( + input=str(file_path), vad_res=vad_res, **cfg + ) + + for i, info in enumerate(res): + [start_ms, end_ms] = info["interval"] + text = info["text"] + emo = info["emo"] + sliced_audio = audio[start_ms:end_ms] + audio_save_path = ( + save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}" + ) + sliced_audio.export(audio_save_path, format=file_suffix[1:]) + print(f"Exported {audio_save_path}: {text}") + + transcript_save_path = ( + save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab" + ) + with open( + transcript_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(text) + + if save_emo: + emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo" + with open( + emo_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(emo) + + if audios_path.resolve() == save_path.resolve(): + file_path.unlink() + + +if __name__ == "__main__": + main() + exit(0) + from funasr.utils.postprocess_utils import rich_transcription_postprocess + + # Load the audio file + audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav") + model_dir = "iic/SenseVoiceSmall" + m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") + m.eval() + + res = m.inference( + data_in=f"{kwargs['model_path']}/example/zh.mp3", + language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + ban_emo_unk=False, + **kwargs, + ) + + print(res) + text = rich_transcription_postprocess(res[0][0]["text"]) + print(text) diff --git a/tools/sensevoice/vad_utils.py b/tools/sensevoice/vad_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3bef75ed8c2841701fff44f7130e91ef8dfdf8cc --- /dev/null +++ b/tools/sensevoice/vad_utils.py @@ -0,0 +1,61 @@ +import torch +from torch.nn.utils.rnn import pad_sequence + + +def slice_padding_fbank(speech, speech_lengths, vad_segments): + speech_list = [] + speech_lengths_list = [] + for i, segment in 
enumerate(vad_segments): + + bed_idx = int(segment[0][0] * 16) + end_idx = min(int(segment[0][1] * 16), speech_lengths[0]) + speech_i = speech[0, bed_idx:end_idx] + speech_lengths_i = end_idx - bed_idx + speech_list.append(speech_i) + speech_lengths_list.append(speech_lengths_i) + feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0) + speech_lengths_pad = torch.Tensor(speech_lengths_list).int() + return feats_pad, speech_lengths_pad + + +def slice_padding_audio_samples(speech, speech_lengths, vad_segments): + speech_list = [] + speech_lengths_list = [] + intervals = [] + for i, segment in enumerate(vad_segments): + bed_idx = int(segment[0][0] * 16) + end_idx = min(int(segment[0][1] * 16), speech_lengths) + speech_i = speech[bed_idx:end_idx] + speech_lengths_i = end_idx - bed_idx + speech_list.append(speech_i) + speech_lengths_list.append(speech_lengths_i) + intervals.append([bed_idx // 16, end_idx // 16]) + + return speech_list, speech_lengths_list, intervals + + +def merge_vad(vad_result, max_length=15000, min_length=0): + new_result = [] + if len(vad_result) <= 1: + return vad_result + time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result] + time_step = sorted(list(set(time_step))) + if len(time_step) == 0: + return [] + bg = 0 + for i in range(len(time_step) - 1): + time = time_step[i] + if time_step[i + 1] - bg < max_length: + continue + if time - bg > min_length: + new_result.append([bg, time]) + # if time - bg < max_length * 1.5: + # new_result.append([bg, time]) + # else: + # split_num = int(time - bg) // max_length + 1 + # spl_l = int(time - bg) // split_num + # for j in range(split_num): + # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l]) + bg = time + new_result.append([bg, time_step[-1]]) + return new_result diff --git a/tools/smart_pad.py b/tools/smart_pad.py new file mode 100644 index 0000000000000000000000000000000000000000..de9dc154f26b2869a7e34f7d4cd95db741ee4c6a --- /dev/null +++ b/tools/smart_pad.py @@ -0,0 +1,60 @@ +import random +from multiprocessing import Pool +from pathlib import Path + +import click +import librosa +import torch.nn.functional as F +import torchaudio +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, list_files + +threshold = 10 ** (-50 / 20.0) + + +def process(file): + waveform, sample_rate = torchaudio.load(str(file), backend="sox") + if waveform.size(0) > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + loudness = librosa.feature.rms( + y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True + )[0] + + for i in range(len(loudness) - 1, 0, -1): + if loudness[i] > threshold: + break + + end_silent_time = (len(loudness) - i) * 512 / sample_rate + + if end_silent_time <= 0.3: + random_time = random.uniform(0.3, 0.7) - end_silent_time + waveform = F.pad( + waveform, (0, int(random_time * sample_rate)), mode="constant", value=0 + ) + + for i in range(len(loudness)): + if loudness[i] > threshold: + break + + start_silent_time = i * 512 / sample_rate + + if start_silent_time > 0.02: + waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :] + + torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate) + + +@click.command() +@click.argument("source", type=Path) +@click.option("--num-workers", type=int, default=12) +def main(source, num_workers): + files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True)) + + with Pool(num_workers) as p: + list(tqdm(p.imap_unordered(process, files), total=len(files))) + + +if __name__ == "__main__": + 
main() diff --git a/tools/vqgan/__pycache__/inference.cpython-310.pyc b/tools/vqgan/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4654556683db8aaa2c482d8adfb4294f5cd59e8 Binary files /dev/null and b/tools/vqgan/__pycache__/inference.cpython-310.pyc differ diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py new file mode 100644 index 0000000000000000000000000000000000000000..d24a5f39566c47ea0cb1fc506d463e9c95c3efbc --- /dev/null +++ b/tools/vqgan/create_train_split.py @@ -0,0 +1,83 @@ +import math +from pathlib import Path +from random import Random + +import click +from loguru import logger +from pydub import AudioSegment +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist + + +@click.command() +@click.argument("root", type=click.Path(exists=True, path_type=Path)) +@click.option("--val-ratio", type=float, default=None) +@click.option("--val-count", type=int, default=None) +@click.option("--filelist", default=None, type=Path) +@click.option("--min-duration", default=None, type=float) +@click.option("--max-duration", default=None, type=float) +def main(root, val_ratio, val_count, filelist, min_duration, max_duration): + if filelist: + files = [i[0] for i in load_filelist(filelist)] + else: + files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True) + + if min_duration is None and max_duration is None: + filtered_files = list(map(str, [file.relative_to(root) for file in files])) + else: + filtered_files = [] + for file in tqdm(files): + try: + audio = AudioSegment.from_file(str(file)) + duration = len(audio) / 1000.0 + + if min_duration is not None and duration < min_duration: + logger.info( + f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}" + ) + continue + + if max_duration is not None and duration > max_duration: + logger.info( + f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}" + ) + continue + + filtered_files.append(str(file.relative_to(root))) + except Exception as e: + logger.info(f"Error processing {file}: {e}") + + logger.info( + f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering" + ) + + Random(42).shuffle(filtered_files) + + if val_count is None and val_ratio is None: + logger.info("Validation ratio and count not specified, using min(20%, 100)") + val_size = min(100, math.ceil(len(filtered_files) * 0.2)) + elif val_count is not None and val_ratio is not None: + logger.error("Cannot specify both val_count and val_ratio") + return + elif val_count is not None: + if val_count < 1 or val_count > len(filtered_files): + logger.error("val_count must be between 1 and number of files") + return + val_size = val_count + else: + val_size = math.ceil(len(filtered_files) * val_ratio) + + logger.info(f"Using {val_size} files for validation") + + with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f: + f.write("\n".join(filtered_files[val_size:])) + + with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f: + f.write("\n".join(filtered_files[:val_size])) + + logger.info("Done") + + +if __name__ == "__main__": + main() diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..c24eb3f46ab57fb02930f233a67299cb31c7d7ba --- /dev/null +++ b/tools/vqgan/extract_vq.py @@ -0,0 +1,227 @@ +import os +import subprocess as sp +import sys +import time +from datetime import timedelta +from 
functools import lru_cache +from pathlib import Path +from random import Random + +import click +import numpy as np +import torch +import torchaudio +from hydra import compose, initialize +from hydra.utils import instantiate +from lightning import LightningModule +from loguru import logger +from omegaconf import OmegaConf + +from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) +# This file is used to convert the audio files to text files using the Whisper model. +# It's mainly used to generate the training data for the VQ model. + + +RANK = int(os.environ.get("SLURM_PROCID", 0)) +WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1)) + +logger_format = ( + "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + "{level: <8} | " + "{name}:{function}:{line} | " + "{extra[rank]} - {message}" +) +logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"}) +logger.remove() +logger.add(sys.stderr, format=logger_format) + + +@lru_cache(maxsize=1) +def get_model( + config_name: str = "firefly_gan_vq", + checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + device: str | torch.device = "cuda", +): + with initialize(version_base="1.3", config_path="../../fish_speech/configs"): + cfg = compose(config_name=config_name) + + model = instantiate(cfg) + state_dict = torch.load( + checkpoint_path, + map_location=device, + ) + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + if any("generator" in k for k in state_dict): + state_dict = { + k.replace("generator.", ""): v + for k, v in state_dict.items() + if "generator." in k + } + + model.load_state_dict(state_dict, strict=False) + model.eval() + model.to(device) + + logger.info(f"Loaded model") + return model + + +@torch.inference_mode() +def process_batch(files: list[Path], model) -> float: + wavs = [] + audio_lengths = [] + new_files = [] + max_length = total_time = 0 + + for file in files: + try: + wav, sr = torchaudio.load( + str(file), backend="sox" if sys.platform == "linux" else "soundfile" + ) # Need to install libsox-dev + except Exception as e: + logger.error(f"Error reading {file}: {e}") + continue + + if wav.shape[0] > 1: + wav = wav.mean(dim=0, keepdim=True) + + wav = torchaudio.functional.resample( + wav.cuda(), sr, model.spec_transform.sample_rate + )[0] + total_time += len(wav) / model.spec_transform.sample_rate + max_length = max(max_length, len(wav)) + + wavs.append(wav) + audio_lengths.append(len(wav)) + new_files.append(file) + + files = new_files + + # Pad to max length + for i, wav in enumerate(wavs): + wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant") + + audios = torch.stack(wavs, dim=0)[:, None] + audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long) + + # Calculate lengths + indices, feature_lengths = model.encode(audios, audio_lengths) + + # Save to disk + outputs = indices.cpu().numpy() + + for file, length, feature, audio_length in zip( + files, feature_lengths, outputs, audio_lengths + ): + feature = feature[:, :length] + + # (T,) + with open(file.with_suffix(".npy"), "wb") as f: + np.save(f, feature) + + return total_time + + +@click.command() +@click.argument("folder") +@click.option("--num-workers", default=1) +@click.option("--config-name", default="firefly_gan_vq") +@click.option( + "--checkpoint-path", + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +) +@click.option("--batch-size", 
default=64) +@click.option("--filelist", default=None, type=Path) +def main( + folder: str, + num_workers: int, + config_name: str, + checkpoint_path: str, + batch_size: int, + filelist: Path, +): + if num_workers > 1 and WORLD_SIZE != num_workers: + assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both" + + logger.info(f"Spawning {num_workers} workers") + + if torch.cuda.is_available(): + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if visible_devices is None: + visible_devices = list(range(torch.cuda.device_count())) + else: + visible_devices = visible_devices.split(",") + else: + # Set to empty string to avoid using GPU + visible_devices = [""] + + processes = [] + for i in range(num_workers): + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)]) + env["SLURM_PROCID"] = str(i) + env["SLURM_NTASKS"] = str(num_workers) + + processes.append( + sp.Popen( + [sys.executable] + sys.argv.copy(), + env=env, + ) + ) + + for p in processes: + p.wait() + + logger.info(f"All workers finished") + return + + # This is a worker + logger.info(f"Starting worker") + if filelist: + files = [i[0] for i in load_filelist(filelist)] + else: + files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False) + + print(f"Found {len(files)} files") + files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()] + + total_files = len(files) + files = files[RANK::WORLD_SIZE] + logger.info(f"Processing {len(files)}/{total_files} files") + + # Batch processing + total_time = 0 + begin_time = time.time() + processed_files = 0 + model = get_model(config_name, checkpoint_path) + + for n_batch, idx in enumerate(range(0, len(files), batch_size)): + batch = files[idx : idx + batch_size] + batch_time = process_batch(batch, model) + + total_time += batch_time + processed_files += len(batch) + + if (n_batch + 1) % 10 == 0: + eta = ( + (time.time() - begin_time) + / processed_files + * (len(files) - processed_files) + ) + logger.info( + f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, " + + f"ETA: {timedelta(seconds=round(eta))}s" + ) + + logger.info( + f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio" + ) + + +if __name__ == "__main__": + main() diff --git a/tools/vqgan/inference.py b/tools/vqgan/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..b6bc7531c41455c346109bdaaa43dafc1e3508a4 --- /dev/null +++ b/tools/vqgan/inference.py @@ -0,0 +1,122 @@ +from pathlib import Path + +import click +import hydra +import numpy as np +import soundfile as sf +import torch +import torchaudio +from hydra import compose, initialize +from hydra.utils import instantiate +from loguru import logger +from omegaconf import OmegaConf + +from tools.file import AUDIO_EXTENSIONS + +# register eval resolver +OmegaConf.register_new_resolver("eval", eval) + + +def load_model(config_name, checkpoint_path, device="cuda"): + hydra.core.global_hydra.GlobalHydra.instance().clear() + with initialize(version_base="1.3", config_path="../../fish_speech/configs"): + cfg = compose(config_name=config_name) + + model = instantiate(cfg) + state_dict = torch.load( + checkpoint_path, + map_location=device, + ) + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + if any("generator" in k for k in state_dict): + state_dict = { + k.replace("generator.", ""): v + for k, v in state_dict.items() + if "generator." 
in k + } + + result = model.load_state_dict(state_dict, strict=False) + model.eval() + model.to(device) + + logger.info(f"Loaded model: {result}") + return model + + +@torch.no_grad() +@click.command() +@click.option( + "--input-path", + "-i", + default="test.wav", + type=click.Path(exists=True, path_type=Path), +) +@click.option( + "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path) +) +@click.option("--config-name", default="firefly_gan_vq") +@click.option( + "--checkpoint-path", + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", +) +@click.option( + "--device", + "-d", + default="cuda", +) +def main(input_path, output_path, config_name, checkpoint_path, device): + model = load_model(config_name, checkpoint_path, device=device) + + if input_path.suffix in AUDIO_EXTENSIONS: + logger.info(f"Processing in-place reconstruction of {input_path}") + + # Load audio + audio, sr = torchaudio.load(str(input_path)) + if audio.shape[0] > 1: + audio = audio.mean(0, keepdim=True) + audio = torchaudio.functional.resample( + audio, sr, model.spec_transform.sample_rate + ) + + audios = audio[None].to(device) + logger.info( + f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds" + ) + + # VQ Encoder + audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long) + indices = model.encode(audios, audio_lengths)[0][0] + + logger.info(f"Generated indices of shape {indices.shape}") + + # Save indices + np.save(output_path.with_suffix(".npy"), indices.cpu().numpy()) + elif input_path.suffix == ".npy": + logger.info(f"Processing precomputed indices from {input_path}") + indices = np.load(input_path) + indices = torch.from_numpy(indices).to(device).long() + assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}" + else: + raise ValueError(f"Unknown input type: {input_path}") + + # Restore + feature_lengths = torch.tensor([indices.shape[1]], device=device) + fake_audios, _ = model.decode( + indices=indices[None], feature_lengths=feature_lengths + ) + audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate + + logger.info( + f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}" + ) + + # Save audio + fake_audio = fake_audios[0, 0].float().cpu().numpy() + sf.write(output_path, fake_audio, model.spec_transform.sample_rate) + logger.info(f"Saved audio to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/webui.py b/tools/webui.py new file mode 100644 index 0000000000000000000000000000000000000000..cff155d48967b4d3980e280096cafc511009a737 --- /dev/null +++ b/tools/webui.py @@ -0,0 +1,485 @@ +import gc +import html +import io +import os +import queue +import wave +from argparse import ArgumentParser +from functools import partial +from pathlib import Path + +import gradio as gr +import librosa +import numpy as np +import pyrootutils +import torch +from loguru import logger +from transformers import AutoTokenizer + +pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) + + +from fish_speech.i18n import i18n +from fish_speech.text.chn_text_norm.text import Text as ChnNormedText +from fish_speech.utils import autocast_exclude_mps +from tools.api import decode_vq_tokens, encode_reference +from tools.llama.generate import ( + GenerateRequest, + GenerateResponse, + WrappedGenerateResponse, + launch_thread_safe_queue, +) +from 
tools.vqgan.inference import load_model as load_decoder_model + +# Make einx happy +os.environ["EINX_FILTER_TRACEBACK"] = "false" + + +HEADER_MD = f"""# Fish Speech + +{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")} + +{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")} + +{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")} + +{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")} +""" + +TEXTBOX_PLACEHOLDER = i18n("Put your text here.") +SPACE_IMPORTED = False + + +def build_html_error_message(error): + return f""" +
+ <div style="color: red; font-weight: bold;"> + {html.escape(str(error))} + </div>
+ """ + + +@torch.inference_mode() +def inference( + text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + streaming=False, +): + if args.max_gradio_length > 0 and len(text) > args.max_gradio_length: + return ( + None, + None, + i18n("Text is too long, please keep it under {} characters.").format( + args.max_gradio_length + ), + ) + + # Parse reference audio aka prompt + prompt_tokens = encode_reference( + decoder_model=decoder_model, + reference_audio=reference_audio, + enable_reference_audio=enable_reference_audio, + ) + + # LLAMA Inference + request = dict( + device=decoder_model.device, + max_new_tokens=max_new_tokens, + text=text, + top_p=top_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + compile=args.compile, + iterative_prompt=chunk_length > 0, + chunk_length=chunk_length, + max_length=2048, + prompt_tokens=prompt_tokens if enable_reference_audio else None, + prompt_text=reference_text if enable_reference_audio else None, + ) + + response_queue = queue.Queue() + llama_queue.put( + GenerateRequest( + request=request, + response_queue=response_queue, + ) + ) + + if streaming: + yield wav_chunk_header(), None, None + + segments = [] + + while True: + result: WrappedGenerateResponse = response_queue.get() + if result.status == "error": + yield None, None, build_html_error_message(result.response) + break + + result: GenerateResponse = result.response + if result.action == "next": + break + + with autocast_exclude_mps( + device_type=decoder_model.device.type, dtype=args.precision + ): + fake_audios = decode_vq_tokens( + decoder_model=decoder_model, + codes=result.codes, + ) + + fake_audios = fake_audios.float().cpu().numpy() + segments.append(fake_audios) + + if streaming: + yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None + + if len(segments) == 0: + return ( + None, + None, + build_html_error_message( + i18n("No audio generated, please check the input text.") + ), + ) + + # No matter streaming or not, we need to return the final audio + audio = np.concatenate(segments, axis=0) + yield None, (decoder_model.spec_transform.sample_rate, audio), None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + +inference_stream = partial(inference, streaming=True) + +n_audios = 4 + +global_audio_list = [] +global_error_list = [] + + +def inference_wrapper( + text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + batch_infer_num, +): + audios = [] + errors = [] + + for _ in range(batch_infer_num): + result = inference( + text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + ) + + _, audio_data, error_message = next(result) + + audios.append( + gr.Audio(value=audio_data if audio_data else None, visible=True), + ) + errors.append( + gr.HTML(value=error_message if error_message else None, visible=True), + ) + + for _ in range(batch_infer_num, n_audios): + audios.append( + gr.Audio(value=None, visible=False), + ) + errors.append( + gr.HTML(value=None, visible=False), + ) + + return None, *audios, *errors + + +def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1): + buffer = io.BytesIO() + + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(channels) + wav_file.setsampwidth(bit_depth // 8) + 
wav_file.setframerate(sample_rate) + + wav_header_bytes = buffer.getvalue() + buffer.close() + return wav_header_bytes + + +def normalize_text(user_input, use_normalization): + if use_normalization: + return ChnNormedText(raw_text=user_input).normalize() + else: + return user_input + + +asr_model = None + + +def build_app(): + with gr.Blocks(theme=gr.themes.Base()) as app: + gr.Markdown(HEADER_MD) + + # Use light theme by default + app.load( + None, + None, + js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}" + % args.theme, + ) + + # Inference + with gr.Row(): + with gr.Column(scale=3): + text = gr.Textbox( + label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10 + ) + refined_text = gr.Textbox( + label=i18n("Realtime Transform Text"), + placeholder=i18n( + "Normalization Result Preview (Currently Only Chinese)" + ), + lines=5, + interactive=False, + ) + + with gr.Row(): + if_refine_text = gr.Checkbox( + label=i18n("Text Normalization"), + value=False, + scale=1, + ) + + with gr.Row(): + with gr.Tab(label=i18n("Advanced Config")): + chunk_length = gr.Slider( + label=i18n("Iterative Prompt Length, 0 means off"), + minimum=50, + maximum=300, + value=200, + step=8, + ) + + max_new_tokens = gr.Slider( + label=i18n("Maximum tokens per batch, 0 means no limit"), + minimum=0, + maximum=2048, + value=1024, # 0 means no limit + step=8, + ) + + top_p = gr.Slider( + label="Top-P", + minimum=0.6, + maximum=0.9, + value=0.7, + step=0.01, + ) + + repetition_penalty = gr.Slider( + label=i18n("Repetition Penalty"), + minimum=1, + maximum=1.5, + value=1.2, + step=0.01, + ) + + temperature = gr.Slider( + label="Temperature", + minimum=0.6, + maximum=0.9, + value=0.7, + step=0.01, + ) + + with gr.Tab(label=i18n("Reference Audio")): + gr.Markdown( + i18n( + "5 to 10 seconds of reference audio, useful for specifying speaker." 
+ ) + ) + + enable_reference_audio = gr.Checkbox( + label=i18n("Enable Reference Audio"), + ) + reference_audio = gr.Audio( + label=i18n("Reference Audio"), + type="filepath", + ) + with gr.Row(): + reference_text = gr.Textbox( + label=i18n("Reference Text"), + lines=1, + placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。", + value="", + ) + with gr.Tab(label=i18n("Batch Inference")): + batch_infer_num = gr.Slider( + label="Batch infer nums", + minimum=1, + maximum=n_audios, + step=1, + value=1, + ) + + with gr.Column(scale=3): + for _ in range(n_audios): + with gr.Row(): + error = gr.HTML( + label=i18n("Error Message"), + visible=True if _ == 0 else False, + ) + global_error_list.append(error) + with gr.Row(): + audio = gr.Audio( + label=i18n("Generated Audio"), + type="numpy", + interactive=False, + visible=True if _ == 0 else False, + ) + global_audio_list.append(audio) + + with gr.Row(): + stream_audio = gr.Audio( + label=i18n("Streaming Audio"), + streaming=True, + autoplay=True, + interactive=False, + show_download_button=True, + ) + with gr.Row(): + with gr.Column(scale=3): + generate = gr.Button( + value="\U0001F3A7 " + i18n("Generate"), variant="primary" + ) + generate_stream = gr.Button( + value="\U0001F3A7 " + i18n("Streaming Generate"), + variant="primary", + ) + + text.input( + fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text] + ) + + # # Submit + generate.click( + inference_wrapper, + [ + refined_text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + batch_infer_num, + ], + [stream_audio, *global_audio_list, *global_error_list], + concurrency_limit=1, + ) + + generate_stream.click( + inference_stream, + [ + refined_text, + enable_reference_audio, + reference_audio, + reference_text, + max_new_tokens, + chunk_length, + top_p, + repetition_penalty, + temperature, + ], + [stream_audio, global_audio_list[0], global_error_list[0]], + concurrency_limit=10, + ) + return app + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "--llama-checkpoint-path", + type=Path, + default="checkpoints/fish-speech-1.4", + ) + parser.add_argument( + "--decoder-checkpoint-path", + type=Path, + default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", + ) + parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--half", action="store_true") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--max-gradio-length", type=int, default=0) + parser.add_argument("--theme", type=str, default="light") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + args.precision = torch.half if args.half else torch.bfloat16 + + logger.info("Loading Llama model...") + llama_queue = launch_thread_safe_queue( + checkpoint_path=args.llama_checkpoint_path, + device=args.device, + precision=args.precision, + compile=args.compile, + ) + logger.info("Llama model loaded, loading VQ-GAN model...") + + decoder_model = load_decoder_model( + config_name=args.decoder_config_name, + checkpoint_path=args.decoder_checkpoint_path, + device=args.device, + ) + + logger.info("Decoder model loaded, warming up...") + + # Dry run to check if the model is loaded correctly and avoid the first-time latency + list( + inference( + text="Hello, world!", + enable_reference_audio=False, + reference_audio=None, + 
reference_text="", + max_new_tokens=1024, + chunk_length=200, + top_p=0.7, + repetition_penalty=1.2, + temperature=0.7, + ) + ) + + logger.info("Warming up done, launching the web UI...") + + app = build_app() + app.launch(show_api=True) diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..42e7de8a185880d3f2afd368d6df3429488465a4 --- /dev/null +++ b/tools/whisper_asr.py @@ -0,0 +1,176 @@ +""" +Used to transcribe all audio files in one folder into another folder. +e.g. +Directory structure: +--pre_data_root +----SP_1 +------01.wav +------02.wav +------...... +----SP_2 +------01.wav +------02.wav +------...... +Use +python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1 +to transcribe the first speaker. + +Use +python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2 +to transcribe the second speaker. + +Note: Be aware of your audio sample rate, which defaults to 44.1kHz. +""" + +import re +from pathlib import Path + +import click +import soundfile as sf +from faster_whisper import WhisperModel +from loguru import logger +from pydub import AudioSegment +from tqdm import tqdm + +from tools.file import AUDIO_EXTENSIONS, list_files + + +@click.command() +@click.option("--model-size", default="large-v3", help="Size of the Whisper model") +@click.option( + "--compute-type", + default="float16", + help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]", +) +@click.option("--audio-dir", required=True, help="Directory containing audio files") +@click.option( + "--save-dir", required=True, help="Directory to save processed audio files" +) +@click.option( + "--sample-rate", + default=44100, + type=int, + help="Output sample rate, default to input sample rate", +) +@click.option("--device", default="cuda", help="Device to use [cuda / cpu]") +@click.option("--language", default="auto", help="Language of the transcription") +@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing") +def main( + model_size, + compute_type, + audio_dir, + save_dir, + sample_rate, + device, + language, + initial_prompt, +): + logger.info("Loading / Downloading Faster Whisper model...") + + model = WhisperModel( + model_size, + device=device, + compute_type=compute_type, + download_root="faster_whisper", + ) + + logger.info("Model loaded.") + + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + audio_files = list_files( + path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True + ) + + for file_path in tqdm(audio_files, desc="Processing audio file"): + file_stem = file_path.stem + file_suffix = file_path.suffix + + rel_path = Path(file_path).relative_to(audio_dir) + (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) + + audio = AudioSegment.from_file(file_path) + + segments, info = model.transcribe( + file_path, + beam_size=5, + language=None if language == "auto" else language, + initial_prompt=initial_prompt, + ) + + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + print("Total len(ms): ", len(audio)) + + whole_text = None + for segment in segments: + id, start, end, text = ( + segment.id, + segment.start, + segment.end, + segment.text, + ) + print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text)) + if not whole_text: + whole_text = text + else: + whole_text += ", " + text + + whole_text += "." 
+ + audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}" + audio.export(audio_save_path, format=file_suffix[1:]) + print(f"Exported {audio_save_path}") + + transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab" + with open( + transcript_save_path, + "w", + encoding="utf-8", + ) as f: + f.write(whole_text) + + +if __name__ == "__main__": + main() + exit(0) + + audio = AudioSegment.from_wav( + r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav" + ) + + model_size = "large-v3" + + model = WhisperModel( + model_size, + device="cuda", + compute_type="float16", + download_root="faster_whisper", + ) + + segments, info = model.transcribe( + r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav", + beam_size=5, + ) + + print( + "Detected language '%s' with probability %f" + % (info.language, info.language_probability) + ) + print("Total len(ms): ", len(audio)) + + for i, segment in enumerate(segments): + print( + "Segment %03d [%.2fs -> %.2fs] %s" + % (i, segment.start, segment.end, segment.text) + ) + start_ms = int(segment.start * 1000) + end_ms = int(segment.end * 1000) + segment_audio = audio[start_ms:end_ms] + segment_audio.export(f"segment_{i:03d}.wav", format="wav") + print(f"Exported segment_{i:03d}.wav") + + print("All segments have been exported.")
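The new tools/vqgan/inference.py above can also be driven from Python rather than through its CLI. The following is a minimal sketch only, reusing the load_model / encode / decode calls that file defines; the paths ("checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth", "test.wav", "fake.wav") simply mirror the CLI defaults and are illustrative assumptions, as is running it from the repository root.

    # Sketch: same encode/decode round trip as tools/vqgan/inference.py,
    # but driven programmatically. Paths below are assumptions that mirror
    # the CLI defaults in the diff, not values the diff guarantees.
    import soundfile as sf
    import torch
    import torchaudio

    from tools.vqgan.inference import load_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(
        config_name="firefly_gan_vq",
        checkpoint_path="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        device=device,
    )

    # Load a mono waveform and resample it to the model's sample rate.
    audio, sr = torchaudio.load("test.wav")
    audio = audio.mean(0, keepdim=True)
    audio = torchaudio.functional.resample(audio, sr, model.spec_transform.sample_rate)

    audios = audio[None].to(device)  # (batch=1, channels=1, samples)
    audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)

    with torch.no_grad():
        # Encode to discrete VQ codes, then decode back to a waveform.
        indices = model.encode(audios, audio_lengths)[0][0]
        feature_lengths = torch.tensor([indices.shape[1]], device=device)
        fake_audios, _ = model.decode(indices=indices[None], feature_lengths=feature_lengths)

    sf.write("fake.wav", fake_audios[0, 0].float().cpu().numpy(), model.spec_transform.sample_rate)

This is the same flow the CLI performs via "python tools/vqgan/inference.py -i test.wav -o fake.wav"; tools/vqgan/extract_vq.py batches the encode half of it over a whole dataset and writes the resulting codes as .npy files next to each audio file.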