import abc
import functools
import io
import json
import logging
import os
import tarfile
import typing

import torch.utils.data
import torchaudio
import transformers
import vocos
from torchvision.datasets.utils import download_url

from modules.ChatTTS.ChatTTS.utils.infer_utils import (
    apply_character_map,
    count_invalid_characters,
)


class LazyDataType(typing.TypedDict):
    """Metadata of one sample, available before any audio/text processing."""

    filepath: str
    speaker: str
    lang: str
    text: str


class DataType(LazyDataType):
    """A processed sample; the shape comments describe the collated batch."""

    text_input_ids: torch.Tensor  # (batch_size, text_len)
    text_attention_mask: torch.Tensor  # (batch_size, text_len)
    audio_mel_specs: torch.Tensor  # (batch_size, audio_len*2, 100)
    audio_attention_mask: torch.Tensor  # (batch_size, audio_len)


class XzListTarKwargsType(typing.TypedDict):
    """Keyword arguments that can be forwarded to the dataset constructors."""

    tokenizer: typing.Union[transformers.PreTrainedTokenizer, None]
    vocos_model: typing.Union[vocos.Vocos, None]
    device: typing.Union[str, torch.device, None]
    speakers: typing.Union[typing.Iterable[str], None]
    sample_rate: typing.Union[int]
    default_speaker: typing.Union[str, None]
    default_lang: typing.Union[str, None]
    tar_in_memory: typing.Union[bool, None]
    process_ahead: typing.Union[bool, None]


def _read_text(root) -> str:
    """Return the UTF-8 text behind ``root``.

    ``root`` is either something ``open`` accepts (a path) or a readable
    stream such as ``io.BytesIO`` or the object returned by
    ``TarFile.extractfile``. Line endings of stream content are normalised to
    ``"\\n"`` so both kinds of input behave like a file opened in text mode.
    """
    if not hasattr(root, "read"):
        with open(root, "r", encoding="utf-8") as f:
            return f.read()
    content = root.read()
    if isinstance(content, (bytes, bytearray)):
        content = bytes(content).decode("utf-8")
    return content.replace("\r\n", "\n").replace("\r", "\n")


class AudioFolder(torch.utils.data.Dataset, abc.ABC):
    """Abstract dataset that pairs audio files with text.

    Subclasses implement ``get_raw_data`` (parse an index file into a list of
    dicts) and ``save_config`` (write such an index file). Mel spectrograms
    and token ids are computed on first access and cached per index.
    """

    def __init__(
        self,
        root: str | io.BytesIO,
        tokenizer: transformers.PreTrainedTokenizer | None = None,
        vocos_model: vocos.Vocos | None = None,
        device: str | torch.device | None = None,
        speakers: typing.Iterable[str] | None = None,
        sample_rate: int = 24_000,
        default_speaker: str | None = None,
        default_lang: str | None = None,
        tar_path: str | None = None,
        tar_in_memory: bool = False,
        process_ahead: bool = False,
    ) -> None:
        """Index the dataset and optionally preprocess every sample.

        Args:
            root: Index file (path or stream) passed to ``get_raw_data``.
            tokenizer: Tokenizer used by ``preprocess_text``.
            vocos_model: Model whose ``feature_extractor`` produces the mels.
            device: Device of the returned tensors; defaults to the device of
                the vocos model parameters.
            speakers: If given, only samples of these speakers are kept.
            sample_rate: Audio is resampled to this rate when it differs.
            default_speaker: Speaker used for items without a ``speaker`` key.
            default_lang: Language used for items without a ``lang`` key.
            tar_path: Optional tar archive the audio files are read from.
            tar_in_memory: Load the whole archive into memory before opening.
            process_ahead: Preprocess all samples now instead of lazily.
        """
        self.root = root
        self.sample_rate = sample_rate
        self.default_speaker = default_speaker
        self.default_lang = default_lang

        self.logger = logging.getLogger(__name__)
        self.normalizer = {}

        self.tokenizer = tokenizer
        self.vocos = vocos_model
        self.vocos_device = (
            None if self.vocos is None else next(self.vocos.parameters()).device
        )
        self.device = device or self.vocos_device

        # tar -cvf ../Xz.tar *
        # tar -xf Xz.tar -C ./Xz
        self.tar_path = tar_path
        self.tar_file = None
        self.tar_io = None
        if tar_path is not None:
            if tar_in_memory:
                with open(tar_path, "rb") as f:
                    self.tar_io = io.BytesIO(f.read())
                self.tar_file = tarfile.open(fileobj=self.tar_io)
            else:
                self.tar_file = tarfile.open(tar_path)

        self.lazy_data, self.speakers = self.get_lazy_data(root, speakers)

        # Per-index caches filled by __getitem__ (or right here).
        self.text_input_ids: dict[int, torch.Tensor] = {}
        self.audio_mel_specs: dict[int, torch.Tensor] = {}
        if process_ahead:
            for n, item in enumerate(self.lazy_data):
                self.audio_mel_specs[n] = self.preprocess_audio(item["filepath"])
                self.text_input_ids[n] = self.preprocess_text(
                    item["text"], item["lang"]
                )
            # Everything is cached, the archive is no longer needed.
            self._close_tar()

    def _close_tar(self) -> None:
        """Close the tar archive and its in-memory buffer, if they exist."""
        if self.tar_file is not None:
            self.tar_file.close()
        if self.tar_io is not None:
            self.tar_io.close()

    @abc.abstractmethod
    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """Parse the index ``root`` into a list of item dicts."""
        ...

    @staticmethod
    @abc.abstractmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        """Write ``lazy_data`` to ``save_path`` in the subclass' format."""
        ...

    def __len__(self):
        return len(self.lazy_data)

    def __getitem__(self, n: int) -> DataType:
        """Return sample ``n``, computing and caching its tensors if needed."""
        lazy_data = self.lazy_data[n]
        if n < 0:
            # Cache under the canonical index: otherwise -1 and len-1 would be
            # two cache entries and the "all cached" check below could close
            # the archive while samples are still unprocessed.
            n += len(self.lazy_data)

        if n in self.audio_mel_specs:
            audio_mel_specs = self.audio_mel_specs[n]
            text_input_ids = self.text_input_ids[n]
        else:
            audio_mel_specs = self.preprocess_audio(lazy_data["filepath"])
            text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"])
            self.audio_mel_specs[n] = audio_mel_specs
            self.text_input_ids[n] = text_input_ids
            if len(self.audio_mel_specs) == len(self.lazy_data):
                # Last sample has been processed, release the archive.
                self._close_tar()

        text_attention_mask = torch.ones(
            len(text_input_ids), device=text_input_ids.device
        )
        # One mask entry per two mel frames (rounded up).
        audio_attention_mask = torch.ones(
            (len(audio_mel_specs) + 1) // 2,
            device=audio_mel_specs.device,
        )
        return {
            "filepath": lazy_data["filepath"],
            "speaker": lazy_data["speaker"],
            "lang": lazy_data["lang"],
            "text": lazy_data["text"],
            "text_input_ids": text_input_ids,
            "text_attention_mask": text_attention_mask,
            "audio_mel_specs": audio_mel_specs,
            "audio_attention_mask": audio_attention_mask,
        }

    def get_lazy_data(
        self,
        root: str | io.BytesIO,
        speakers: typing.Iterable[str] | None = None,
    ) -> tuple[list[LazyDataType], set[str]]:
        """Build the sample list from the raw index.

        Args:
            root: Index file passed to ``get_raw_data``.
            speakers: Optional whitelist; items of other speakers are dropped.

        Returns:
            ``(lazy_data, speakers)`` where ``speakers`` is the whitelist as a
            set, or the set of all speakers encountered when no whitelist was
            given.
        """
        # Materialise once: ``speakers`` may be a one-shot iterable, and set
        # membership is O(1).
        allowed = None if speakers is None else set(speakers)
        new_speakers: set[str] = set() if allowed is None else set(allowed)

        lazy_data: list[LazyDataType] = []
        raw_data = self.get_raw_data(root)
        folder_path = os.path.dirname(root) if isinstance(root, str) else ""
        # Paths are resolved against the index folder only when reading from
        # disk; with a tar archive (or a stream index) they are used verbatim.
        resolve_on_disk = self.tar_file is None and isinstance(root, str)

        for item in raw_data:
            if "speaker" not in item:
                item["speaker"] = self.default_speaker
            if "lang" not in item:
                item["lang"] = self.default_lang

            if allowed is None:
                new_speakers.add(item["speaker"])
            elif item["speaker"] not in allowed:
                continue

            if resolve_on_disk:
                filepath = os.path.join(folder_path, item["filepath"])
            else:
                filepath = item["filepath"]
            lazy_data.append(
                {
                    "filepath": filepath,
                    "speaker": item["speaker"],
                    "lang": item["lang"].lower(),
                    "text": item["text"],
                }
            )
        return lazy_data, new_speakers

    def preprocess_text(
        self,
        text: str,
        lang: str,
    ) -> torch.Tensor:
        """Tokenize ``text`` wrapped in the TTS prompt markers.

        ``lang`` is accepted for interface symmetry but not used here.
        Returns the 1-D ``input_ids`` tensor on ``self.device``.
        """
        invalid_characters = count_invalid_characters(text)
        if len(invalid_characters):
            text = apply_character_map(text)

        text = f"[Stts][spk_emb]{text}[Ptts]"

        text_token = self.tokenizer(
            text, return_tensors="pt", add_special_tokens=False
        ).to(device=self.device)
        return text_token["input_ids"].squeeze(0)

    def preprocess_audio(self, filepath: str) -> torch.Tensor:
        """Load ``filepath`` (from the tar if one is open) and return its mel."""
        if self.tar_file is not None:
            file = self.tar_file.extractfile(filepath)
            waveform, sample_rate = torchaudio.load(file)
        else:
            waveform, sample_rate = torchaudio.load(filepath)
        waveform = waveform.to(device=self.vocos_device)
        if sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform,
                orig_freq=sample_rate,
                new_freq=self.sample_rate,
            )
        mel_spec: torch.Tensor = self.vocos.feature_extractor(waveform)
        return (
            mel_spec.to(device=self.device).squeeze(0).transpose(0, 1)
        )  # (audio_len*2, 100)


class JsonFolder(AudioFolder):
    """
    In json file, each item is formatted as following example:
    `{"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}`.

    filepath is relative to the dirname of root json file.
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """Load the item list from a JSON path or readable stream."""
        return json.loads(_read_text(root))

    @staticmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        """Dump ``lazy_data`` as JSON with filepaths relative to ``rel_path``."""
        save_data = [item.copy() for item in lazy_data]
        for item in save_data:
            item["filepath"] = os.path.relpath(item["filepath"], rel_path)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(save_data, f, ensure_ascii=False, indent=4)


class ListFolder(AudioFolder):
    """
    In list file, each row is formatted as `filepath|speaker|lang|text` with `|` as separator.
    `path/to/file.wav|John|ZH|Hello`.

    filepath is relative to the dirname of root list file.
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """Parse a list file (path or readable stream); blank rows are skipped."""
        raw_data = []
        for line in _read_text(root).split("\n"):
            line = line.strip()
            if not line:
                continue
            # maxsplit keeps any further "|" inside the text field.
            filepath, speaker, lang, text = line.split(sep="|", maxsplit=3)
            raw_data.append(
                {
                    "text": text,
                    "filepath": filepath,
                    "speaker": speaker,
                    "lang": lang,
                }
            )
        return raw_data

    @staticmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        """Write ``lazy_data`` as list rows with filepaths relative to ``rel_path``."""
        save_data = [item.copy() for item in lazy_data]
        for item in save_data:
            item["filepath"] = os.path.relpath(item["filepath"], rel_path)
        with open(save_path, "w", encoding="utf-8") as f:
            for item in save_data:
                f.write(
                    f"{item['filepath']}|{item['speaker']}|{item['lang']}|{item['text']}\n"
                )


class XzListTar(ListFolder):
    """List dataset whose audio lives in a tar archive.

    ``root`` is a list file, a folder holding ``all.list``, or a stream. When
    the list file does not exist yet it is generated from the index files
    found inside the archive.
    """

    def __init__(
        self,
        *args,
        root: str | io.BytesIO,
        tar_path: str | None = None,
        **kwargs,
    ):
        """Resolve ``root`` to a list file, creating it from the tar if missing.

        Accepts an extra ``langs`` keyword (default ``("zh", "en")``) that is
        only used when ``all.list`` has to be generated; all other arguments
        are forwarded to ``AudioFolder.__init__``.

        Raises:
            FileExistsError: ``root`` is an existing file not ending in ``.list``.
            FileNotFoundError: the list file is missing and no ``tar_path``
                was given to build it from.
        """
        # Must not reach AudioFolder.__init__, which does not accept it.
        langs = kwargs.pop("langs", ("zh", "en"))

        if isinstance(root, io.BytesIO):
            assert tar_path is not None
        else:
            # make sure root is a list file
            if not root.endswith(".list"):  # folder case
                if os.path.isfile(root):
                    raise FileExistsError(f"{root} is a file!")
                elif not os.path.exists(root):
                    os.makedirs(root)
                root = os.path.join(root, "all.list")

        if isinstance(root, str) and not os.path.isfile(root):
            # prepare all.list
            if tar_path is None:
                raise FileNotFoundError(
                    f"{root} does not exist and no tar_path was given to build it"
                )
            # concat_dataset reads self.tar_file, which the base constructor
            # has not set up yet: open the archive just for this step.
            self.tar_file = tarfile.open(tar_path)
            try:
                self.concat_dataset(
                    save_folder=os.path.dirname(root) or ".",
                    langs=langs,
                )
            finally:
                self.tar_file.close()
                self.tar_file = None

        super().__init__(root, *args, tar_path=tar_path, **kwargs)

    def concat_dataset(
        self,
        save_folder: str | None = None,
        langs: typing.Sequence[str] | None = ("zh", "en"),
    ) -> None:
        """Merge every ``.list``/``.json`` index in the tar into one index.

        Writes ``all.list`` and ``all.json`` to ``save_folder`` (defaults to
        the folder of ``self.root``). Items whose language is not in ``langs``
        are dropped; ``None`` keeps all of them. Requires ``self.tar_file`` to
        be an open archive.

        Raises:
            FileExistsError: ``save_folder`` is an existing file.
        """
        if save_folder is None:
            save_folder = os.path.dirname(self.root)
        if os.path.isfile(save_folder):
            raise FileExistsError(f"{save_folder} already exists as a file!")
        elif not os.path.exists(save_folder):
            os.makedirs(save_folder)

        lazy_data = []
        for member in self.tar_file.getmembers():
            if not member.isfile():
                continue
            if member.name.endswith(".list"):
                print(member.name)
                root_io = self.tar_file.extractfile(member)
                lazy_data += ListFolder(root_io).lazy_data
            elif member.name.endswith(".json"):
                print(member.name)
                root_io = self.tar_file.extractfile(member)
                lazy_data += JsonFolder(root_io).lazy_data

        if langs is not None:
            lazy_data = [item for item in lazy_data if item["lang"] in langs]

        ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data)
        JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data)
        print(f"all.list and all.json are saved to {save_folder}")


class XzListFolder(ListFolder):
    """
    [Xz乔希](https://space.bilibili.com/5859321)

    Only look at the basename of filepath in list file. Previous folder paths are ignored.
    Files are organized as `[list basename]/[file basename]`

    Example tree structure:

    [folder]
    ├── speaker_A
    │   ├── 1.wav
    │   └── 2.wav
    ├── speaker_A.list
    ├── speaker_B
    │   ├── 1.wav
    │   └── 2.wav
    └── speaker_B.list
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """Parse the list and rewrite each filepath to `[list basename]/[file basename]`."""
        raw_data = super().get_raw_data(root)
        # Folder named after the list file, e.g. "speaker_A.list" -> "speaker_A".
        audio_folder = os.path.basename(root).removesuffix(".list")
        for item in raw_data:
            item["filepath"] = os.path.join(
                audio_folder,
                os.path.basename(item["filepath"]),
            )
        return raw_data


class AudioCollator:
    """Collate ``DataType`` samples into a padded batch.

    Text tensors are padded on the left, audio tensors on the right.
    """

    def __init__(self, text_pad: int = 0, audio_pad: int = 0):
        self.text_pad = text_pad
        self.audio_pad = audio_pad

    def __call__(self, batch: list[DataType]):
        """Pad every sample to the batch maximum and stack the tensors.

        ``None`` entries in ``batch`` are ignored. String fields are returned
        as plain lists in batch order.
        """
        batch = [x for x in batch if x is not None]

        audio_maxlen = max(len(item["audio_attention_mask"]) for item in batch)
        text_maxlen = max(len(item["text_attention_mask"]) for item in batch)
        pad = torch.nn.functional.pad

        filepath = []
        speaker = []
        lang = []
        text = []
        text_input_ids = []
        text_attention_mask = []
        audio_mel_specs = []
        audio_attention_mask = []

        for x in batch:
            filepath.append(x["filepath"])
            speaker.append(x["speaker"])
            lang.append(x["lang"])
            text.append(x["text"])

            # Left padding for text: (pad_before, pad_after) on the last dim.
            text_input_ids.append(
                pad(
                    x["text_input_ids"],
                    (text_maxlen - len(x["text_input_ids"]), 0),
                    value=self.text_pad,
                )
            )
            text_attention_mask.append(
                pad(
                    x["text_attention_mask"],
                    (text_maxlen - len(x["text_attention_mask"]), 0),
                    value=0,
                )
            )
            # Right padding for audio; the mel has two frames per mask entry,
            # and only its first (time) dimension is padded.
            audio_mel_specs.append(
                pad(
                    x["audio_mel_specs"],
                    (0, 0, 0, audio_maxlen * 2 - len(x["audio_mel_specs"])),
                    value=self.audio_pad,
                )
            )
            audio_attention_mask.append(
                pad(
                    x["audio_attention_mask"],
                    (0, audio_maxlen - len(x["audio_attention_mask"])),
                    value=0,
                )
            )

        return {
            "filepath": filepath,
            "speaker": speaker,
            "lang": lang,
            "text": text,
            "text_input_ids": torch.stack(text_input_ids),
            "text_attention_mask": torch.stack(text_attention_mask),
            "audio_mel_specs": torch.stack(audio_mel_specs),
            "audio_attention_mask": torch.stack(audio_attention_mask),
        }


def formalize_xz_list(src_folder: str):
    """Rewrite every ``.list`` file under ``src_folder`` in place.

    Each file is parsed with ``XzListFolder`` and saved back with filepaths
    relative to ``src_folder``.
    """
    for root, _, files in os.walk(src_folder):
        for file in files:
            if file.endswith(".list"):
                filepath = os.path.join(root, file)
                print(filepath)
                lazy_data = XzListFolder(filepath).lazy_data
                XzListFolder.save_config(filepath, lazy_data, rel_path=src_folder)


def concat_dataset(
    src_folder: str,
    save_folder: str | None = None,
    langs: typing.Sequence[str] | None = ("zh", "en"),
) -> None:
    """Merge every ``.list``/``.json`` index under ``src_folder`` into one.

    Writes ``all.list`` and ``all.json`` to ``save_folder`` (defaults to
    ``src_folder``) with filepaths relative to ``save_folder``. Items whose
    language is not in ``langs`` are dropped; ``None`` keeps all of them.

    Raises:
        FileExistsError: ``save_folder`` is an existing file.
    """
    if save_folder is None:
        save_folder = src_folder
    if os.path.isfile(save_folder):
        raise FileExistsError(f"{save_folder} already exists as a file!")
    elif not os.path.exists(save_folder):
        os.makedirs(save_folder)

    lazy_data = []
    same_folder = os.path.samefile(src_folder, save_folder)
    for root, _, files in os.walk(src_folder):
        for file in files:
            filepath = os.path.join(root, file)
            # Do not feed a previous run's output back into the merge.
            if same_folder and file in ("all.list", "all.json"):
                continue
            if file.endswith(".list"):
                print(filepath)
                lazy_data += ListFolder(filepath).lazy_data
            elif file.endswith(".json"):
                print(filepath)
                lazy_data += JsonFolder(filepath).lazy_data

    if langs is not None:
        lazy_data = [item for item in lazy_data if item["lang"] in langs]

    ListFolder.save_config(
        os.path.join(save_folder, "all.list"), lazy_data, rel_path=save_folder
    )
    JsonFolder.save_config(
        os.path.join(save_folder, "all.json"), lazy_data, rel_path=save_folder
    )
    print(f"all.list and all.json are saved to {save_folder}")