+```
+
+## Options
+
+| Option | Description |
+| :-----------------------: | :---------------------------------------------------------------------------: |
+| --audio-dir | Directory containing audio or video files. |
+| --save-dir | Directory to save processed audio files. |
+| --device | Device to use for processing. Options: cuda (default) or cpu. |
+| --language | Language of the transcription. Default is auto. |
+| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
+| --punc | Enable punctuation prediction. |
+| --denoise | Enable noise reduction (vocal separation). |
+
+## Example
+
+To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
+
+```bash
+python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
+```
+
+## Additional Notes
+
+- The tool supports `both audio and video files`. Videos will be converted to audio automatically.
+- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
+- The script will automatically create necessary directories in the `--save-dir`.
+
+## Troubleshooting
+
+If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
diff --git a/tools/sensevoice/__init__.py b/tools/sensevoice/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tools/sensevoice/auto_model.py b/tools/sensevoice/auto_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2e186617fe889500d01d95eccdafc5c0248b84
--- /dev/null
+++ b/tools/sensevoice/auto_model.py
@@ -0,0 +1,573 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import copy
+import json
+import logging
+import os.path
+import random
+import re
+import string
+import time
+
+import numpy as np
+import torch
+from funasr.download.download_model_from_hub import download_model
+from funasr.download.file import download_from_url
+from funasr.register import tables
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import export_utils, misc
+from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
+from funasr.utils.misc import deep_update
+from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
+from tqdm import tqdm
+
+from .vad_utils import merge_vad, slice_padding_audio_samples
+
+try:
+ from funasr.models.campplus.cluster_backend import ClusterBackend
+ from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
+except:
+ pass
+
+
+def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
+ """ """
+ data_list = []
+ key_list = []
+ filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
+
+ chars = string.ascii_letters + string.digits
+ if isinstance(data_in, str):
+ if data_in.startswith("http://") or data_in.startswith("https://"): # url
+ data_in = download_from_url(data_in)
+
+ if isinstance(data_in, str) and os.path.exists(
+ data_in
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
+ _, file_extension = os.path.splitext(data_in)
+ file_extension = file_extension.lower()
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
+ with open(data_in, encoding="utf-8") as fin:
+ for line in fin:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ if data_in.endswith(
+ ".jsonl"
+ ): # file.jsonl: json.dumps({"source": data})
+ lines = json.loads(line.strip())
+ data = lines["source"]
+ key = data["key"] if "key" in data else key
+ else: # filelist, wav.scp, text.txt: id \t data or data
+ lines = line.strip().split(maxsplit=1)
+ data = lines[1] if len(lines) > 1 else lines[0]
+ key = lines[0] if len(lines) > 1 else key
+
+ data_list.append(data)
+ key_list.append(key)
+ else:
+ if key is None:
+ # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ key = misc.extract_filename_without_extension(data_in)
+ data_list = [data_in]
+ key_list = [key]
+ elif isinstance(data_in, (list, tuple)):
+ if data_type is not None and isinstance(
+ data_type, (list, tuple)
+ ): # mutiple inputs
+ data_list_tmp = []
+ for data_in_i, data_type_i in zip(data_in, data_type):
+ key_list, data_list_i = prepare_data_iterator(
+ data_in=data_in_i, data_type=data_type_i
+ )
+ data_list_tmp.append(data_list_i)
+ data_list = []
+ for item in zip(*data_list_tmp):
+ data_list.append(item)
+ else:
+ # [audio sample point, fbank, text]
+ data_list = data_in
+ key_list = []
+ for data_i in data_in:
+ if isinstance(data_i, str) and os.path.exists(data_i):
+ key = misc.extract_filename_without_extension(data_i)
+ else:
+ if key is None:
+ key = "rand_key_" + "".join(
+ random.choice(chars) for _ in range(13)
+ )
+ key_list.append(key)
+
+ else: # raw text; audio sample point, fbank; bytes
+ if isinstance(data_in, bytes): # audio bytes
+ data_in = load_bytes(data_in)
+ if key is None:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ data_list = [data_in]
+ key_list = [key]
+
+ return key_list, data_list
+
+
+class AutoModel:
+
+ def __init__(self, **kwargs):
+
+ try:
+ from funasr.utils.version_checker import check_for_update
+
+ print(
+ "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
+ )
+ check_for_update(disable=kwargs.get("disable_update", False))
+ except:
+ pass
+
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+ logging.basicConfig(level=log_level)
+
+ model, kwargs = self.build_model(**kwargs)
+
+ # if vad_model is not None, build vad model else None
+ vad_model = kwargs.get("vad_model", None)
+ vad_kwargs = (
+ {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
+ )
+ if vad_model is not None:
+ logging.info("Building VAD model.")
+ vad_kwargs["model"] = vad_model
+ vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
+ vad_kwargs["device"] = kwargs["device"]
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
+
+ # if punc_model is not None, build punc model else None
+ punc_model = kwargs.get("punc_model", None)
+ punc_kwargs = (
+ {}
+ if kwargs.get("punc_kwargs", {}) is None
+ else kwargs.get("punc_kwargs", {})
+ )
+ if punc_model is not None:
+ logging.info("Building punc model.")
+ punc_kwargs["model"] = punc_model
+ punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
+ punc_kwargs["device"] = kwargs["device"]
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
+
+ # if spk_model is not None, build spk model else None
+ spk_model = kwargs.get("spk_model", None)
+ spk_kwargs = (
+ {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
+ )
+ if spk_model is not None:
+ logging.info("Building SPK model.")
+ spk_kwargs["model"] = spk_model
+ spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
+ spk_kwargs["device"] = kwargs["device"]
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
+ self.cb_model = ClusterBackend().to(kwargs["device"])
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
+ logging.error(
+ "spk_mode should be one of default, vad_segment and punc_segment."
+ )
+ self.spk_mode = spk_mode
+
+ self.kwargs = kwargs
+ self.model = model
+ self.vad_model = vad_model
+ self.vad_kwargs = vad_kwargs
+ self.punc_model = punc_model
+ self.punc_kwargs = punc_kwargs
+ self.spk_model = spk_model
+ self.spk_kwargs = spk_kwargs
+ self.model_path = kwargs.get("model_path")
+
+ @staticmethod
+ def build_model(**kwargs):
+ assert "model" in kwargs
+ if "model_conf" not in kwargs:
+ logging.info(
+ "download models from model hub: {}".format(kwargs.get("hub", "ms"))
+ )
+ kwargs = download_model(**kwargs)
+
+ set_all_random_seed(kwargs.get("seed", 0))
+
+ device = kwargs.get("device", "cuda")
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+ device = "cpu"
+ kwargs["batch_size"] = 1
+ kwargs["device"] = device
+
+ torch.set_num_threads(kwargs.get("ncpu", 4))
+
+ # build tokenizer
+ tokenizer = kwargs.get("tokenizer", None)
+ if tokenizer is not None:
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+ tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
+ kwargs["token_list"] = (
+ tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+ )
+ kwargs["token_list"] = (
+ tokenizer.get_vocab()
+ if hasattr(tokenizer, "get_vocab")
+ else kwargs["token_list"]
+ )
+ vocab_size = (
+ len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
+ )
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+ vocab_size = tokenizer.get_vocab_size()
+ else:
+ vocab_size = -1
+ kwargs["tokenizer"] = tokenizer
+
+ # build frontend
+ frontend = kwargs.get("frontend", None)
+ kwargs["input_size"] = None
+ if frontend is not None:
+ frontend_class = tables.frontend_classes.get(frontend)
+ frontend = frontend_class(**kwargs.get("frontend_conf", {}))
+ kwargs["input_size"] = (
+ frontend.output_size() if hasattr(frontend, "output_size") else None
+ )
+ kwargs["frontend"] = frontend
+ # build model
+ model_class = tables.model_classes.get(kwargs["model"])
+ assert model_class is not None, f'{kwargs["model"]} is not registered'
+ model_conf = {}
+ deep_update(model_conf, kwargs.get("model_conf", {}))
+ deep_update(model_conf, kwargs)
+ model = model_class(**model_conf, vocab_size=vocab_size)
+
+ # init_param
+ init_param = kwargs.get("init_param", None)
+ if init_param is not None:
+ if os.path.exists(init_param):
+ logging.info(f"Loading pretrained params from {init_param}")
+ load_pretrained_model(
+ model=model,
+ path=init_param,
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+ oss_bucket=kwargs.get("oss_bucket", None),
+ scope_map=kwargs.get("scope_map", []),
+ excludes=kwargs.get("excludes", None),
+ )
+ else:
+ print(f"error, init_param does not exist!: {init_param}")
+
+ # fp16
+ if kwargs.get("fp16", False):
+ model.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ model.to(torch.bfloat16)
+ model.to(device)
+
+ if not kwargs.get("disable_log", True):
+ tables.print()
+
+ return model, kwargs
+
+ def __call__(self, *args, **cfg):
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ res = self.model(*args, kwargs)
+ return res
+
+ def generate(self, input, input_len=None, **cfg):
+ if self.vad_model is None:
+ return self.inference(input, input_len=input_len, **cfg)
+
+ else:
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
+
+ def inference(
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
+ ):
+ kwargs = self.kwargs if kwargs is None else kwargs
+ if "cache" in kwargs:
+ kwargs.pop("cache")
+ deep_update(kwargs, cfg)
+ model = self.model if model is None else model
+ model.eval()
+
+ batch_size = kwargs.get("batch_size", 1)
+ # if kwargs.get("device", "cpu") == "cpu":
+ # batch_size = 1
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
+ )
+
+ speed_stats = {}
+ asr_result_list = []
+ num_samples = len(data_list)
+ disable_pbar = self.kwargs.get("disable_pbar", False)
+ pbar = (
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+ if not disable_pbar
+ else None
+ )
+ time_speech_total = 0.0
+ time_escape_total = 0.0
+ for beg_idx in range(0, num_samples, batch_size):
+ end_idx = min(num_samples, beg_idx + batch_size)
+ data_batch = data_list[beg_idx:end_idx]
+ key_batch = key_list[beg_idx:end_idx]
+ batch = {"data_in": data_batch, "key": key_batch}
+
+ if (end_idx - beg_idx) == 1 and kwargs.get(
+ "data_type", None
+ ) == "fbank": # fbank
+ batch["data_in"] = data_batch[0]
+ batch["data_lengths"] = input_len
+
+ time1 = time.perf_counter()
+ with torch.no_grad():
+ res = model.inference(**batch, **kwargs)
+ if isinstance(res, (list, tuple)):
+ results = res[0] if len(res) > 0 else [{"text": ""}]
+ meta_data = res[1] if len(res) > 1 else {}
+ time2 = time.perf_counter()
+
+ asr_result_list.extend(results)
+
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+ batch_data_time = meta_data.get("batch_data_time", -1)
+ time_escape = time2 - time1
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+ speed_stats["forward"] = f"{time_escape:0.3f}"
+ speed_stats["batch_size"] = f"{len(results)}"
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
+ description = f"{speed_stats}, "
+ if pbar:
+ pbar.update(end_idx - beg_idx)
+ pbar.set_description(description)
+ time_speech_total += batch_data_time
+ time_escape_total += time_escape
+
+ if pbar:
+ # pbar.update(1)
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
+ torch.cuda.empty_cache()
+ return asr_result_list
+
+ def vad(self, input, input_len=None, **cfg):
+ kwargs = self.kwargs
+ # step.1: compute the vad model
+ deep_update(self.vad_kwargs, cfg)
+ beg_vad = time.time()
+ res = self.inference(
+ input,
+ input_len=input_len,
+ model=self.vad_model,
+ kwargs=self.vad_kwargs,
+ **cfg,
+ )
+ end_vad = time.time()
+ # FIX(gcf): concat the vad clips for sense vocie model for better aed
+ if cfg.get("merge_vad", False):
+ for i in range(len(res)):
+ res[i]["value"] = merge_vad(
+ res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
+ )
+ elapsed = end_vad - beg_vad
+ return elapsed, res
+
+ def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
+
+ kwargs = self.kwargs
+
+ # step.2 compute asr model
+ model = self.model
+ deep_update(kwargs, cfg)
+ batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
+ kwargs["batch_size"] = batch_size
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
+ )
+ results_ret_list = []
+ time_speech_total_all_samples = 1e-6
+
+ beg_total = time.time()
+ pbar_total = (
+ tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
+ if not kwargs.get("disable_pbar", False)
+ else None
+ )
+
+ for i in range(len(vad_res)):
+ key = vad_res[i]["key"]
+ vadsegments = vad_res[i]["value"]
+ input_i = data_list[i]
+ fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
+ speech = load_audio_text_image_video(
+ input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
+ )
+ speech_lengths = len(speech)
+ n = len(vadsegments)
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+ results_sorted = []
+
+ if not len(sorted_data):
+ results_ret_list.append({"key": key, "text": "", "timestamp": []})
+ logging.info("decoding, utt: {}, empty speech".format(key))
+ continue
+
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+ batch_size = max(
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
+ )
+
+ if kwargs["device"] == "cpu":
+ batch_size = 0
+
+ beg_idx = 0
+ beg_asr_total = time.time()
+ time_speech_total_per_sample = speech_lengths / 16000
+ time_speech_total_all_samples += time_speech_total_per_sample
+
+ # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
+
+ all_segments = []
+ max_len_in_batch = 0
+ end_idx = 1
+
+ for j, _ in enumerate(range(0, n)):
+ # pbar_sample.update(1)
+ sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
+ potential_batch_length = max(max_len_in_batch, sample_length) * (
+ j + 1 - beg_idx
+ )
+ # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
+ if (
+ j < n - 1
+ and sample_length < batch_size_threshold_ms
+ and potential_batch_length < batch_size
+ ):
+ max_len_in_batch = max(max_len_in_batch, sample_length)
+ end_idx += 1
+ continue
+
+ speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
+ )
+ results = self.inference(
+ speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
+ )
+
+ for _b in range(len(speech_j)):
+ results[_b]["interval"] = intervals[_b]
+
+ if self.spk_model is not None:
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
+ for _b in range(len(speech_j)):
+ vad_segments = [
+ [
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
+ np.array(speech_j[_b]),
+ ]
+ ]
+ segments = sv_chunk(vad_segments)
+ all_segments.extend(segments)
+ speech_b = [i[2] for i in segments]
+ spk_res = self.inference(
+ speech_b,
+ input_len=None,
+ model=self.spk_model,
+ kwargs=kwargs,
+ **cfg,
+ )
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
+
+ beg_idx = end_idx
+ end_idx += 1
+ max_len_in_batch = sample_length
+ if len(results) < 1:
+ continue
+ results_sorted.extend(results)
+
+ # end_asr_total = time.time()
+ # time_escape_total_per_sample = end_asr_total - beg_asr_total
+ # pbar_sample.update(1)
+ # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
+ # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+
+ restored_data = [0] * n
+ for j in range(n):
+ index = sorted_data[j][1]
+ cur = results_sorted[j]
+ pattern = r"<\|([^|]+)\|>"
+ emotion_string = re.findall(pattern, cur["text"])
+ cur["text"] = re.sub(pattern, "", cur["text"])
+ cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
+ if self.punc_model is not None and len(cur["text"].strip()) > 0:
+ deep_update(self.punc_kwargs, cfg)
+ punc_res = self.inference(
+ cur["text"],
+ model=self.punc_model,
+ kwargs=self.punc_kwargs,
+ **cfg,
+ )
+ cur["text"] = punc_res[0]["text"]
+
+ restored_data[index] = cur
+
+ end_asr_total = time.time()
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
+ if pbar_total:
+ pbar_total.update(1)
+ pbar_total.set_description(
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
+ )
+
+ # end_total = time.time()
+ # time_escape_total_all_samples = end_total - beg_total
+ # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
+ # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
+ # f"time_escape_all: {time_escape_total_all_samples:0.3f}")
+ return restored_data
+
+ def export(self, input=None, **cfg):
+ """
+
+ :param input:
+ :param type:
+ :param quantize:
+ :param fallback_num:
+ :param calib_num:
+ :param opset_version:
+ :param cfg:
+ :return:
+ """
+
+ device = cfg.get("device", "cpu")
+ model = self.model.to(device=device)
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ kwargs["device"] = device
+ del kwargs["model"]
+ model.eval()
+
+ type = kwargs.get("type", "onnx")
+
+ key_list, data_list = prepare_data_iterator(
+ input, input_len=None, data_type=kwargs.get("data_type", None), key=None
+ )
+
+ with torch.no_grad():
+ export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
+
+ return export_dir
diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..6789316d5186db69c021758094649553c3638f66
--- /dev/null
+++ b/tools/sensevoice/fun_asr.py
@@ -0,0 +1,332 @@
+import gc
+import os
+import re
+
+from audio_separator.separator import Separator
+
+os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
+os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
+import json
+import subprocess
+from pathlib import Path
+
+import click
+import torch
+from loguru import logger
+from pydub import AudioSegment
+from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
+from tools.sensevoice.auto_model import AutoModel
+
+
+def uvr5_cli(
+ audio_dir: Path,
+ output_folder: Path,
+ audio_files: list[Path] | None = None,
+ output_format: str = "flac",
+ model: str = "BS-Roformer-Viperx-1297.ckpt",
+):
+ # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
+ sepr = Separator(
+ model_file_dir=os.environ["UVR5_CACHE"],
+ output_dir=output_folder,
+ output_format=output_format,
+ )
+ dictmodel = {
+ "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
+ "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
+ "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
+ "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
+ }
+ roformer_model = dictmodel[model]
+ sepr.load_model(roformer_model)
+ if audio_files is None:
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+ total_files = len(audio_files)
+
+ print(f"{total_files} audio files found")
+
+ res = []
+ for audio in tqdm(audio_files, desc="Denoising: "):
+ file_path = str(audio_dir / audio)
+ sep_out = sepr.separate(file_path)
+ if isinstance(sep_out, str):
+ res.append(sep_out)
+ elif isinstance(sep_out, list):
+ res.extend(sep_out)
+ del sepr
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return res, roformer_model
+
+
+def get_sample_rate(media_path: Path):
+ result = subprocess.run(
+ [
+ "ffprobe",
+ "-v",
+ "quiet",
+ "-print_format",
+ "json",
+ "-show_streams",
+ str(media_path),
+ ],
+ capture_output=True,
+ text=True,
+ check=True,
+ )
+ media_info = json.loads(result.stdout)
+ for stream in media_info.get("streams", []):
+ if stream.get("codec_type") == "audio":
+ return stream.get("sample_rate")
+ return "44100" # Default sample rate if not found
+
+
+def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
+ sr = get_sample_rate(src_path)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ if src_path.resolve() == out_path.resolve():
+ output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
+ else:
+ output = str(out_path)
+ subprocess.run(
+ [
+ "ffmpeg",
+ "-loglevel",
+ "error",
+ "-i",
+ str(src_path),
+ "-acodec",
+ "pcm_s16le" if out_fmt == "wav" else "flac",
+ "-ar",
+ sr,
+ "-ac",
+ "1",
+ "-y",
+ output,
+ ],
+ check=True,
+ )
+ return out_path
+
+
+def convert_video_to_audio(video_path: Path, audio_dir: Path):
+ cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
+ vocals = [
+ p
+ for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
+ if p.suffix in AUDIO_EXTENSIONS
+ ]
+ if len(vocals) > 0:
+ return vocals[0]
+ audio_path = cur_dir / f"{video_path.stem}.wav"
+ convert_to_mono(video_path, audio_path)
+ return audio_path
+
+
+@click.command()
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option(
+ "--max_single_segment_time",
+ default=20000,
+ type=int,
+ help="Maximum of Output single audio duration(ms)",
+)
+@click.option("--fsmn-vad/--silero-vad", default=False)
+@click.option("--punc/--no-punc", default=False)
+@click.option("--denoise/--no-denoise", default=False)
+@click.option("--save_emo/--no_save_emo", default=False)
+def main(
+ audio_dir: str,
+ save_dir: str,
+ device: str,
+ language: str,
+ max_single_segment_time: int,
+ fsmn_vad: bool,
+ punc: bool,
+ denoise: bool,
+ save_emo: bool,
+):
+
+ audios_path = Path(audio_dir)
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ video_files = list_files(
+ path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
+ )
+ v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
+
+ if denoise:
+ VOCAL = "_(Vocals)"
+ original_files = [
+ p
+ for p in audios_path.glob("**/*")
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
+ ]
+
+ _, cur_model = uvr5_cli(
+ audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
+ )
+ need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
+ need_remove.extend(original_files)
+ for _ in need_remove:
+ _.unlink()
+ vocal_files = [
+ p
+ for p in audios_path.glob("**/*")
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
+ ]
+ for f in vocal_files:
+ fn, ext = f.stem, f.suffix
+
+ v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
+ if v_pos != -1:
+ new_fn = fn[: v_pos + len(VOCAL)]
+ new_f = f.with_name(new_fn + ext)
+ f = f.rename(new_f)
+ convert_to_mono(f, f, "flac")
+ f.unlink()
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ logger.info("Loading / Downloading Funasr model...")
+
+ model_dir = "iic/SenseVoiceSmall"
+
+ vad_model = "fsmn-vad" if fsmn_vad else None
+ vad_kwargs = {"max_single_segment_time": max_single_segment_time}
+ punc_model = "ct-punc" if punc else None
+
+ manager = AutoModel(
+ model=model_dir,
+ trust_remote_code=False,
+ vad_model=vad_model,
+ vad_kwargs=vad_kwargs,
+ punc_model=punc_model,
+ device=device,
+ )
+
+ if not fsmn_vad and vad_model is None:
+ vad_model = load_silero_vad()
+
+ logger.info("Model loaded.")
+
+ pattern = re.compile(r"_\d{3}\.")
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+
+ if pattern.search(file_path.name):
+ # logger.info(f"Skipping {file_path} as it has already been processed.")
+ continue
+
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ audio = AudioSegment.from_file(file_path)
+
+ cfg = dict(
+ cache={},
+ language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=False,
+ batch_size_s=60,
+ )
+
+ if fsmn_vad:
+ elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
+ else:
+ wav = read_audio(
+ str(file_path)
+ ) # backend (sox, soundfile, or ffmpeg) required!
+ audio_key = file_path.stem
+ audio_val = []
+ speech_timestamps = get_speech_timestamps(
+ wav,
+ vad_model,
+ max_speech_duration_s=max_single_segment_time // 1000,
+ return_seconds=True,
+ )
+
+ audio_val = [
+ [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
+ for timestamp in speech_timestamps
+ ]
+ vad_res = []
+ vad_res.append(dict(key=audio_key, value=audio_val))
+
+ res = manager.inference_with_vadres(
+ input=str(file_path), vad_res=vad_res, **cfg
+ )
+
+ for i, info in enumerate(res):
+ [start_ms, end_ms] = info["interval"]
+ text = info["text"]
+ emo = info["emo"]
+ sliced_audio = audio[start_ms:end_ms]
+ audio_save_path = (
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
+ )
+ sliced_audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}: {text}")
+
+ transcript_save_path = (
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
+ )
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(text)
+
+ if save_emo:
+ emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
+ with open(
+ emo_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(emo)
+
+ if audios_path.resolve() == save_path.resolve():
+ file_path.unlink()
+
+
+if __name__ == "__main__":
+ main()
+ exit(0)
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
+
+ # Load the audio file
+ audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
+ model_dir = "iic/SenseVoiceSmall"
+ m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
+ m.eval()
+
+ res = m.inference(
+ data_in=f"{kwargs['model_path']}/example/zh.mp3",
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
+ use_itn=False,
+ ban_emo_unk=False,
+ **kwargs,
+ )
+
+ print(res)
+ text = rich_transcription_postprocess(res[0][0]["text"])
+ print(text)
diff --git a/tools/sensevoice/vad_utils.py b/tools/sensevoice/vad_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bef75ed8c2841701fff44f7130e91ef8dfdf8cc
--- /dev/null
+++ b/tools/sensevoice/vad_utils.py
@@ -0,0 +1,61 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def slice_padding_fbank(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
+ speech_i = speech[0, bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
+ speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
+ return feats_pad, speech_lengths_pad
+
+
+def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ intervals = []
+ for i, segment in enumerate(vad_segments):
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
+ speech_i = speech[bed_idx:end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+ intervals.append([bed_idx // 16, end_idx // 16])
+
+ return speech_list, speech_lengths_list, intervals
+
+
+def merge_vad(vad_result, max_length=15000, min_length=0):
+ new_result = []
+ if len(vad_result) <= 1:
+ return vad_result
+ time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
+ time_step = sorted(list(set(time_step)))
+ if len(time_step) == 0:
+ return []
+ bg = 0
+ for i in range(len(time_step) - 1):
+ time = time_step[i]
+ if time_step[i + 1] - bg < max_length:
+ continue
+ if time - bg > min_length:
+ new_result.append([bg, time])
+ # if time - bg < max_length * 1.5:
+ # new_result.append([bg, time])
+ # else:
+ # split_num = int(time - bg) // max_length + 1
+ # spl_l = int(time - bg) // split_num
+ # for j in range(split_num):
+ # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
+ bg = time
+ new_result.append([bg, time_step[-1]])
+ return new_result
diff --git a/tools/smart_pad.py b/tools/smart_pad.py
new file mode 100644
index 0000000000000000000000000000000000000000..de9dc154f26b2869a7e34f7d4cd95db741ee4c6a
--- /dev/null
+++ b/tools/smart_pad.py
@@ -0,0 +1,60 @@
+import random
+from multiprocessing import Pool
+from pathlib import Path
+
+import click
+import librosa
+import torch.nn.functional as F
+import torchaudio
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files
+
+threshold = 10 ** (-50 / 20.0)
+
+
+def process(file):
+ waveform, sample_rate = torchaudio.load(str(file), backend="sox")
+ if waveform.size(0) > 1:
+ waveform = waveform.mean(dim=0, keepdim=True)
+
+ loudness = librosa.feature.rms(
+ y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
+ )[0]
+
+ for i in range(len(loudness) - 1, 0, -1):
+ if loudness[i] > threshold:
+ break
+
+ end_silent_time = (len(loudness) - i) * 512 / sample_rate
+
+ if end_silent_time <= 0.3:
+ random_time = random.uniform(0.3, 0.7) - end_silent_time
+ waveform = F.pad(
+ waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
+ )
+
+ for i in range(len(loudness)):
+ if loudness[i] > threshold:
+ break
+
+ start_silent_time = i * 512 / sample_rate
+
+ if start_silent_time > 0.02:
+ waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
+
+ torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
+
+
+@click.command()
+@click.argument("source", type=Path)
+@click.option("--num-workers", type=int, default=12)
+def main(source, num_workers):
+ files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
+
+ with Pool(num_workers) as p:
+ list(tqdm(p.imap_unordered(process, files), total=len(files)))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/__pycache__/inference.cpython-310.pyc b/tools/vqgan/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4654556683db8aaa2c482d8adfb4294f5cd59e8
Binary files /dev/null and b/tools/vqgan/__pycache__/inference.cpython-310.pyc differ
diff --git a/tools/vqgan/create_train_split.py b/tools/vqgan/create_train_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..d24a5f39566c47ea0cb1fc506d463e9c95c3efbc
--- /dev/null
+++ b/tools/vqgan/create_train_split.py
@@ -0,0 +1,83 @@
+import math
+from pathlib import Path
+from random import Random
+
+import click
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+
+@click.command()
+@click.argument("root", type=click.Path(exists=True, path_type=Path))
+@click.option("--val-ratio", type=float, default=None)
+@click.option("--val-count", type=int, default=None)
+@click.option("--filelist", default=None, type=Path)
+@click.option("--min-duration", default=None, type=float)
+@click.option("--max-duration", default=None, type=float)
+def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+
+ if min_duration is None and max_duration is None:
+ filtered_files = list(map(str, [file.relative_to(root) for file in files]))
+ else:
+ filtered_files = []
+ for file in tqdm(files):
+ try:
+ audio = AudioSegment.from_file(str(file))
+ duration = len(audio) / 1000.0
+
+ if min_duration is not None and duration < min_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
+ )
+ continue
+
+ if max_duration is not None and duration > max_duration:
+ logger.info(
+ f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
+ )
+ continue
+
+ filtered_files.append(str(file.relative_to(root)))
+ except Exception as e:
+ logger.info(f"Error processing {file}: {e}")
+
+ logger.info(
+ f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
+ )
+
+ Random(42).shuffle(filtered_files)
+
+ if val_count is None and val_ratio is None:
+ logger.info("Validation ratio and count not specified, using min(20%, 100)")
+ val_size = min(100, math.ceil(len(filtered_files) * 0.2))
+ elif val_count is not None and val_ratio is not None:
+ logger.error("Cannot specify both val_count and val_ratio")
+ return
+ elif val_count is not None:
+ if val_count < 1 or val_count > len(filtered_files):
+ logger.error("val_count must be between 1 and number of files")
+ return
+ val_size = val_count
+ else:
+ val_size = math.ceil(len(filtered_files) * val_ratio)
+
+ logger.info(f"Using {val_size} files for validation")
+
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[val_size:]))
+
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
+ f.write("\n".join(filtered_files[:val_size]))
+
+ logger.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/extract_vq.py b/tools/vqgan/extract_vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c24eb3f46ab57fb02930f233a67299cb31c7d7ba
--- /dev/null
+++ b/tools/vqgan/extract_vq.py
@@ -0,0 +1,227 @@
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+# This file is used to convert the audio files to text files using the Whisper model.
+# It's mainly used to generate the training data for the VQ model.
+
+
+RANK = int(os.environ.get("SLURM_PROCID", 0))
+WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+logger_format = (
+ "{time:YYYY-MM-DD HH:mm:ss.SSS} | "
+ "{level: <8} | "
+ "{name}:{function}:{line} | "
+ "{extra[rank]} - {message}"
+)
+logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+logger.remove()
+logger.add(sys.stderr, format=logger_format)
+
+
+@lru_cache(maxsize=1)
+def get_model(
+ config_name: str = "firefly_gan_vq",
+ checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ device: str | torch.device = "cuda",
+):
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=device,
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model")
+ return model
+
+
+@torch.inference_mode()
+def process_batch(files: list[Path], model) -> float:
+ wavs = []
+ audio_lengths = []
+ new_files = []
+ max_length = total_time = 0
+
+ for file in files:
+ try:
+ wav, sr = torchaudio.load(
+ str(file), backend="sox" if sys.platform == "linux" else "soundfile"
+ ) # Need to install libsox-dev
+ except Exception as e:
+ logger.error(f"Error reading {file}: {e}")
+ continue
+
+ if wav.shape[0] > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+
+ wav = torchaudio.functional.resample(
+ wav.cuda(), sr, model.spec_transform.sample_rate
+ )[0]
+ total_time += len(wav) / model.spec_transform.sample_rate
+ max_length = max(max_length, len(wav))
+
+ wavs.append(wav)
+ audio_lengths.append(len(wav))
+ new_files.append(file)
+
+ files = new_files
+
+ # Pad to max length
+ for i, wav in enumerate(wavs):
+ wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+
+ audios = torch.stack(wavs, dim=0)[:, None]
+ audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
+
+ # Calculate lengths
+ indices, feature_lengths = model.encode(audios, audio_lengths)
+
+ # Save to disk
+ outputs = indices.cpu().numpy()
+
+ for file, length, feature, audio_length in zip(
+ files, feature_lengths, outputs, audio_lengths
+ ):
+ feature = feature[:, :length]
+
+ # (T,)
+ with open(file.with_suffix(".npy"), "wb") as f:
+ np.save(f, feature)
+
+ return total_time
+
+
+@click.command()
+@click.argument("folder")
+@click.option("--num-workers", default=1)
+@click.option("--config-name", default="firefly_gan_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+)
+@click.option("--batch-size", default=64)
+@click.option("--filelist", default=None, type=Path)
+def main(
+ folder: str,
+ num_workers: int,
+ config_name: str,
+ checkpoint_path: str,
+ batch_size: int,
+ filelist: Path,
+):
+ if num_workers > 1 and WORLD_SIZE != num_workers:
+ assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+ logger.info(f"Spawning {num_workers} workers")
+
+ if torch.cuda.is_available():
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ if visible_devices is None:
+ visible_devices = list(range(torch.cuda.device_count()))
+ else:
+ visible_devices = visible_devices.split(",")
+ else:
+ # Set to empty string to avoid using GPU
+ visible_devices = [""]
+
+ processes = []
+ for i in range(num_workers):
+ env = os.environ.copy()
+ env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+ env["SLURM_PROCID"] = str(i)
+ env["SLURM_NTASKS"] = str(num_workers)
+
+ processes.append(
+ sp.Popen(
+ [sys.executable] + sys.argv.copy(),
+ env=env,
+ )
+ )
+
+ for p in processes:
+ p.wait()
+
+ logger.info(f"All workers finished")
+ return
+
+ # This is a worker
+ logger.info(f"Starting worker")
+ if filelist:
+ files = [i[0] for i in load_filelist(filelist)]
+ else:
+ files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
+
+ print(f"Found {len(files)} files")
+ files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
+
+ total_files = len(files)
+ files = files[RANK::WORLD_SIZE]
+ logger.info(f"Processing {len(files)}/{total_files} files")
+
+ # Batch processing
+ total_time = 0
+ begin_time = time.time()
+ processed_files = 0
+ model = get_model(config_name, checkpoint_path)
+
+ for n_batch, idx in enumerate(range(0, len(files), batch_size)):
+ batch = files[idx : idx + batch_size]
+ batch_time = process_batch(batch, model)
+
+ total_time += batch_time
+ processed_files += len(batch)
+
+ if (n_batch + 1) % 10 == 0:
+ eta = (
+ (time.time() - begin_time)
+ / processed_files
+ * (len(files) - processed_files)
+ )
+ logger.info(
+ f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+ + f"ETA: {timedelta(seconds=round(eta))}s"
+ )
+
+ logger.info(
+ f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/vqgan/inference.py b/tools/vqgan/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6bc7531c41455c346109bdaaa43dafc1e3508a4
--- /dev/null
+++ b/tools/vqgan/inference.py
@@ -0,0 +1,122 @@
+from pathlib import Path
+
+import click
+import hydra
+import numpy as np
+import soundfile as sf
+import torch
+import torchaudio
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tools.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def load_model(config_name, checkpoint_path, device="cuda"):
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+ cfg = compose(config_name=config_name)
+
+ model = instantiate(cfg)
+ state_dict = torch.load(
+ checkpoint_path,
+ map_location=device,
+ )
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+
+ if any("generator" in k for k in state_dict):
+ state_dict = {
+ k.replace("generator.", ""): v
+ for k, v in state_dict.items()
+ if "generator." in k
+ }
+
+ result = model.load_state_dict(state_dict, strict=False)
+ model.eval()
+ model.to(device)
+
+ logger.info(f"Loaded model: {result}")
+ return model
+
+
+@torch.no_grad()
+@click.command()
+@click.option(
+ "--input-path",
+ "-i",
+ default="test.wav",
+ type=click.Path(exists=True, path_type=Path),
+)
+@click.option(
+ "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
+)
+@click.option("--config-name", default="firefly_gan_vq")
+@click.option(
+ "--checkpoint-path",
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+)
+@click.option(
+ "--device",
+ "-d",
+ default="cuda",
+)
+def main(input_path, output_path, config_name, checkpoint_path, device):
+ model = load_model(config_name, checkpoint_path, device=device)
+
+ if input_path.suffix in AUDIO_EXTENSIONS:
+ logger.info(f"Processing in-place reconstruction of {input_path}")
+
+ # Load audio
+ audio, sr = torchaudio.load(str(input_path))
+ if audio.shape[0] > 1:
+ audio = audio.mean(0, keepdim=True)
+ audio = torchaudio.functional.resample(
+ audio, sr, model.spec_transform.sample_rate
+ )
+
+ audios = audio[None].to(device)
+ logger.info(
+ f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
+ )
+
+ # VQ Encoder
+ audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
+ indices = model.encode(audios, audio_lengths)[0][0]
+
+ logger.info(f"Generated indices of shape {indices.shape}")
+
+ # Save indices
+ np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
+ elif input_path.suffix == ".npy":
+ logger.info(f"Processing precomputed indices from {input_path}")
+ indices = np.load(input_path)
+ indices = torch.from_numpy(indices).to(device).long()
+ assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
+ else:
+ raise ValueError(f"Unknown input type: {input_path}")
+
+ # Restore
+ feature_lengths = torch.tensor([indices.shape[1]], device=device)
+ fake_audios, _ = model.decode(
+ indices=indices[None], feature_lengths=feature_lengths
+ )
+ audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
+
+ logger.info(
+ f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
+ )
+
+ # Save audio
+ fake_audio = fake_audios[0, 0].float().cpu().numpy()
+ sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
+ logger.info(f"Saved audio to {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/webui.py b/tools/webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..cff155d48967b4d3980e280096cafc511009a737
--- /dev/null
+++ b/tools/webui.py
@@ -0,0 +1,485 @@
+import gc
+import html
+import io
+import os
+import queue
+import wave
+from argparse import ArgumentParser
+from functools import partial
+from pathlib import Path
+
+import gradio as gr
+import librosa
+import numpy as np
+import pyrootutils
+import torch
+from loguru import logger
+from transformers import AutoTokenizer
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+
+from fish_speech.i18n import i18n
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+from fish_speech.utils import autocast_exclude_mps
+from tools.api import decode_vq_tokens, encode_reference
+from tools.llama.generate import (
+ GenerateRequest,
+ GenerateResponse,
+ WrappedGenerateResponse,
+ launch_thread_safe_queue,
+)
+from tools.vqgan.inference import load_model as load_decoder_model
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+HEADER_MD = f"""# Fish Speech
+
+{i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
+
+{i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
+
+{i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
+
+{i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
+"""
+
+TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
+SPACE_IMPORTED = False
+
+
+def build_html_error_message(error):
+ return f"""
+
+ {html.escape(str(error))}
+
+ """
+
+
+@torch.inference_mode()
+def inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ streaming=False,
+):
+ if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
+ return (
+ None,
+ None,
+ i18n("Text is too long, please keep it under {} characters.").format(
+ args.max_gradio_length
+ ),
+ )
+
+ # Parse reference audio aka prompt
+ prompt_tokens = encode_reference(
+ decoder_model=decoder_model,
+ reference_audio=reference_audio,
+ enable_reference_audio=enable_reference_audio,
+ )
+
+ # LLAMA Inference
+ request = dict(
+ device=decoder_model.device,
+ max_new_tokens=max_new_tokens,
+ text=text,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ temperature=temperature,
+ compile=args.compile,
+ iterative_prompt=chunk_length > 0,
+ chunk_length=chunk_length,
+ max_length=2048,
+ prompt_tokens=prompt_tokens if enable_reference_audio else None,
+ prompt_text=reference_text if enable_reference_audio else None,
+ )
+
+ response_queue = queue.Queue()
+ llama_queue.put(
+ GenerateRequest(
+ request=request,
+ response_queue=response_queue,
+ )
+ )
+
+ if streaming:
+ yield wav_chunk_header(), None, None
+
+ segments = []
+
+ while True:
+ result: WrappedGenerateResponse = response_queue.get()
+ if result.status == "error":
+ yield None, None, build_html_error_message(result.response)
+ break
+
+ result: GenerateResponse = result.response
+ if result.action == "next":
+ break
+
+ with autocast_exclude_mps(
+ device_type=decoder_model.device.type, dtype=args.precision
+ ):
+ fake_audios = decode_vq_tokens(
+ decoder_model=decoder_model,
+ codes=result.codes,
+ )
+
+ fake_audios = fake_audios.float().cpu().numpy()
+ segments.append(fake_audios)
+
+ if streaming:
+ yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+ if len(segments) == 0:
+ return (
+ None,
+ None,
+ build_html_error_message(
+ i18n("No audio generated, please check the input text.")
+ ),
+ )
+
+ # No matter streaming or not, we need to return the final audio
+ audio = np.concatenate(segments, axis=0)
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+inference_stream = partial(inference, streaming=True)
+
+n_audios = 4
+
+global_audio_list = []
+global_error_list = []
+
+
+def inference_wrapper(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ batch_infer_num,
+):
+ audios = []
+ errors = []
+
+ for _ in range(batch_infer_num):
+ result = inference(
+ text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ )
+
+ _, audio_data, error_message = next(result)
+
+ audios.append(
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
+ )
+ errors.append(
+ gr.HTML(value=error_message if error_message else None, visible=True),
+ )
+
+ for _ in range(batch_infer_num, n_audios):
+ audios.append(
+ gr.Audio(value=None, visible=False),
+ )
+ errors.append(
+ gr.HTML(value=None, visible=False),
+ )
+
+ return None, *audios, *errors
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+ buffer = io.BytesIO()
+
+ with wave.open(buffer, "wb") as wav_file:
+ wav_file.setnchannels(channels)
+ wav_file.setsampwidth(bit_depth // 8)
+ wav_file.setframerate(sample_rate)
+
+ wav_header_bytes = buffer.getvalue()
+ buffer.close()
+ return wav_header_bytes
+
+
+def normalize_text(user_input, use_normalization):
+ if use_normalization:
+ return ChnNormedText(raw_text=user_input).normalize()
+ else:
+ return user_input
+
+
+asr_model = None
+
+
+def build_app():
+ with gr.Blocks(theme=gr.themes.Base()) as app:
+ gr.Markdown(HEADER_MD)
+
+ # Use light theme by default
+ app.load(
+ None,
+ None,
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
+ % args.theme,
+ )
+
+ # Inference
+ with gr.Row():
+ with gr.Column(scale=3):
+ text = gr.Textbox(
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
+ )
+ refined_text = gr.Textbox(
+ label=i18n("Realtime Transform Text"),
+ placeholder=i18n(
+ "Normalization Result Preview (Currently Only Chinese)"
+ ),
+ lines=5,
+ interactive=False,
+ )
+
+ with gr.Row():
+ if_refine_text = gr.Checkbox(
+ label=i18n("Text Normalization"),
+ value=False,
+ scale=1,
+ )
+
+ with gr.Row():
+ with gr.Tab(label=i18n("Advanced Config")):
+ chunk_length = gr.Slider(
+ label=i18n("Iterative Prompt Length, 0 means off"),
+ minimum=50,
+ maximum=300,
+ value=200,
+ step=8,
+ )
+
+ max_new_tokens = gr.Slider(
+ label=i18n("Maximum tokens per batch, 0 means no limit"),
+ minimum=0,
+ maximum=2048,
+ value=1024, # 0 means no limit
+ step=8,
+ )
+
+ top_p = gr.Slider(
+ label="Top-P",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ repetition_penalty = gr.Slider(
+ label=i18n("Repetition Penalty"),
+ minimum=1,
+ maximum=1.5,
+ value=1.2,
+ step=0.01,
+ )
+
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ with gr.Tab(label=i18n("Reference Audio")):
+ gr.Markdown(
+ i18n(
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
+ )
+ )
+
+ enable_reference_audio = gr.Checkbox(
+ label=i18n("Enable Reference Audio"),
+ )
+ reference_audio = gr.Audio(
+ label=i18n("Reference Audio"),
+ type="filepath",
+ )
+ with gr.Row():
+ reference_text = gr.Textbox(
+ label=i18n("Reference Text"),
+ lines=1,
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+ value="",
+ )
+ with gr.Tab(label=i18n("Batch Inference")):
+ batch_infer_num = gr.Slider(
+ label="Batch infer nums",
+ minimum=1,
+ maximum=n_audios,
+ step=1,
+ value=1,
+ )
+
+ with gr.Column(scale=3):
+ for _ in range(n_audios):
+ with gr.Row():
+ error = gr.HTML(
+ label=i18n("Error Message"),
+ visible=True if _ == 0 else False,
+ )
+ global_error_list.append(error)
+ with gr.Row():
+ audio = gr.Audio(
+ label=i18n("Generated Audio"),
+ type="numpy",
+ interactive=False,
+ visible=True if _ == 0 else False,
+ )
+ global_audio_list.append(audio)
+
+ with gr.Row():
+ stream_audio = gr.Audio(
+ label=i18n("Streaming Audio"),
+ streaming=True,
+ autoplay=True,
+ interactive=False,
+ show_download_button=True,
+ )
+ with gr.Row():
+ with gr.Column(scale=3):
+ generate = gr.Button(
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
+ )
+ generate_stream = gr.Button(
+ value="\U0001F3A7 " + i18n("Streaming Generate"),
+ variant="primary",
+ )
+
+ text.input(
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
+ )
+
+ # # Submit
+ generate.click(
+ inference_wrapper,
+ [
+ refined_text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ batch_infer_num,
+ ],
+ [stream_audio, *global_audio_list, *global_error_list],
+ concurrency_limit=1,
+ )
+
+ generate_stream.click(
+ inference_stream,
+ [
+ refined_text,
+ enable_reference_audio,
+ reference_audio,
+ reference_text,
+ max_new_tokens,
+ chunk_length,
+ top_p,
+ repetition_penalty,
+ temperature,
+ ],
+ [stream_audio, global_audio_list[0], global_error_list[0]],
+ concurrency_limit=10,
+ )
+ return app
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--llama-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.4",
+ )
+ parser.add_argument(
+ "--decoder-checkpoint-path",
+ type=Path,
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+ )
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--half", action="store_true")
+ parser.add_argument("--compile", action="store_true")
+ parser.add_argument("--max-gradio-length", type=int, default=0)
+ parser.add_argument("--theme", type=str, default="light")
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ args.precision = torch.half if args.half else torch.bfloat16
+
+ logger.info("Loading Llama model...")
+ llama_queue = launch_thread_safe_queue(
+ checkpoint_path=args.llama_checkpoint_path,
+ device=args.device,
+ precision=args.precision,
+ compile=args.compile,
+ )
+ logger.info("Llama model loaded, loading VQ-GAN model...")
+
+ decoder_model = load_decoder_model(
+ config_name=args.decoder_config_name,
+ checkpoint_path=args.decoder_checkpoint_path,
+ device=args.device,
+ )
+
+ logger.info("Decoder model loaded, warming up...")
+
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
+ list(
+ inference(
+ text="Hello, world!",
+ enable_reference_audio=False,
+ reference_audio=None,
+ reference_text="",
+ max_new_tokens=1024,
+ chunk_length=200,
+ top_p=0.7,
+ repetition_penalty=1.2,
+ temperature=0.7,
+ )
+ )
+
+ logger.info("Warming up done, launching the web UI...")
+
+ app = build_app()
+ app.launch(show_api=True)
diff --git a/tools/whisper_asr.py b/tools/whisper_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..42e7de8a185880d3f2afd368d6df3429488465a4
--- /dev/null
+++ b/tools/whisper_asr.py
@@ -0,0 +1,176 @@
+"""
+Used to transcribe all audio files in one folder into another folder.
+e.g.
+Directory structure:
+--pre_data_root
+----SP_1
+------01.wav
+------02.wav
+------......
+----SP_2
+------01.wav
+------02.wav
+------......
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
+to transcribe the first speaker.
+
+Use
+python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
+to transcribe the second speaker.
+
+Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
+"""
+
+import re
+from pathlib import Path
+
+import click
+import soundfile as sf
+from faster_whisper import WhisperModel
+from loguru import logger
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from tools.file import AUDIO_EXTENSIONS, list_files
+
+
+@click.command()
+@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
+@click.option(
+ "--compute-type",
+ default="float16",
+ help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
+)
+@click.option("--audio-dir", required=True, help="Directory containing audio files")
+@click.option(
+ "--save-dir", required=True, help="Directory to save processed audio files"
+)
+@click.option(
+ "--sample-rate",
+ default=44100,
+ type=int,
+ help="Output sample rate, default to input sample rate",
+)
+@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
+@click.option("--language", default="auto", help="Language of the transcription")
+@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
+def main(
+ model_size,
+ compute_type,
+ audio_dir,
+ save_dir,
+ sample_rate,
+ device,
+ language,
+ initial_prompt,
+):
+ logger.info("Loading / Downloading Faster Whisper model...")
+
+ model = WhisperModel(
+ model_size,
+ device=device,
+ compute_type=compute_type,
+ download_root="faster_whisper",
+ )
+
+ logger.info("Model loaded.")
+
+ save_path = Path(save_dir)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ audio_files = list_files(
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
+ )
+
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
+ file_stem = file_path.stem
+ file_suffix = file_path.suffix
+
+ rel_path = Path(file_path).relative_to(audio_dir)
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
+
+ audio = AudioSegment.from_file(file_path)
+
+ segments, info = model.transcribe(
+ file_path,
+ beam_size=5,
+ language=None if language == "auto" else language,
+ initial_prompt=initial_prompt,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+ whole_text = None
+ for segment in segments:
+ id, start, end, text = (
+ segment.id,
+ segment.start,
+ segment.end,
+ segment.text,
+ )
+ print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
+ if not whole_text:
+ whole_text = text
+ else:
+ whole_text += ", " + text
+
+ whole_text += "."
+
+ audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
+ audio.export(audio_save_path, format=file_suffix[1:])
+ print(f"Exported {audio_save_path}")
+
+ transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
+ with open(
+ transcript_save_path,
+ "w",
+ encoding="utf-8",
+ ) as f:
+ f.write(whole_text)
+
+
+if __name__ == "__main__":
+ main()
+ exit(0)
+
+ audio = AudioSegment.from_wav(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
+ )
+
+ model_size = "large-v3"
+
+ model = WhisperModel(
+ model_size,
+ device="cuda",
+ compute_type="float16",
+ download_root="faster_whisper",
+ )
+
+ segments, info = model.transcribe(
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
+ beam_size=5,
+ )
+
+ print(
+ "Detected language '%s' with probability %f"
+ % (info.language, info.language_probability)
+ )
+ print("Total len(ms): ", len(audio))
+
+ for i, segment in enumerate(segments):
+ print(
+ "Segment %03d [%.2fs -> %.2fs] %s"
+ % (i, segment.start, segment.end, segment.text)
+ )
+ start_ms = int(segment.start * 1000)
+ end_ms = int(segment.end * 1000)
+ segment_audio = audio[start_ms:end_ms]
+ segment_audio.export(f"segment_{i:03d}.wav", format="wav")
+ print(f"Exported segment_{i:03d}.wav")
+
+ print("All segments have been exported.")