Spaces:
Sleeping
Sleeping
import abc | |
import functools | |
import io | |
import json | |
import logging | |
import os | |
import tarfile | |
import typing | |
import torch.utils.data | |
import torchaudio | |
import transformers | |
import vocos | |
from torchvision.datasets.utils import download_url | |
from modules.ChatTTS.ChatTTS.utils.infer_utils import ( | |
apply_character_map, | |
count_invalid_characters, | |
) | |
class LazyDataType(typing.TypedDict):
    """Metadata of one dataset entry, known before any audio/text processing.

    Keys:
        filepath: audio file path, or the member name when audio is read from
            a tar archive.
        speaker: speaker label of the utterance.
        lang: language tag (lower-cased by ``AudioFolder.get_lazy_data``).
        text: transcript of the utterance.
    """

    filepath: str
    speaker: str
    lang: str
    text: str
class DataType(LazyDataType):
    """One fully processed sample, as returned by ``AudioFolder.__getitem__``.

    The shapes below are per sample. ``AudioCollator`` pads and stacks samples,
    which adds a leading ``batch_size`` dimension to every tensor.
    """

    text_input_ids: torch.Tensor  # (text_len,)
    text_attention_mask: torch.Tensor  # (text_len,), all ones before collation
    audio_mel_specs: torch.Tensor  # (mel_frames, 100)
    audio_attention_mask: torch.Tensor  # (audio_len,), audio_len = (mel_frames + 1) // 2
class XzListTarKwargsType(typing.TypedDict):
    """Optional keyword arguments of the ``AudioFolder`` constructor family.

    Each key has the same meaning as the ``AudioFolder.__init__`` parameter of
    the same name (``root`` and ``tar_path`` are not part of this mapping).
    """

    tokenizer: transformers.PreTrainedTokenizer | None
    vocos_model: vocos.Vocos | None
    device: str | torch.device | None
    speakers: typing.Iterable[str] | None
    sample_rate: int
    default_speaker: str | None
    default_lang: str | None
    tar_in_memory: bool | None
    process_ahead: bool | None
class AudioFolder(torch.utils.data.Dataset, abc.ABC):
    """Abstract dataset of (audio file, transcript) pairs listed in a metadata file.

    Subclasses define the metadata format through :meth:`get_raw_data` (parse)
    and :meth:`save_config` (write back). Audio is read from the filesystem, or
    from a tar archive when ``tar_path`` is given. Each item is turned into
    tokenizer ids (:meth:`preprocess_text`) and a vocos mel spectrogram
    (:meth:`preprocess_audio`). Results are cached per index, either eagerly
    (``process_ahead=True``) or on first access.
    """

    def __init__(
        self,
        root: str | io.BytesIO,
        tokenizer: transformers.PreTrainedTokenizer | None = None,
        vocos_model: vocos.Vocos | None = None,
        device: str | torch.device | None = None,
        speakers: typing.Iterable[str] | None = None,
        sample_rate: int = 24_000,
        default_speaker: str | None = None,
        default_lang: str | None = None,
        tar_path: str | None = None,
        tar_in_memory: bool = False,
        process_ahead: bool = False,
    ) -> None:
        """
        Args:
            root: metadata file (path or file object) parsed by ``get_raw_data``.
            tokenizer: tokenizer used by ``preprocess_text``.
            vocos_model: model whose ``feature_extractor`` computes mel features.
            device: device of the returned tensors; defaults to the device of
                ``vocos_model``'s first parameter.
            speakers: if given, only items of these speakers are kept.
            sample_rate: target sample rate; audio is resampled when it differs.
            default_speaker: speaker for items without a ``speaker`` key.
            default_lang: language for items without a ``lang`` key.
            tar_path: optional tar archive holding the audio; ``filepath``
                entries are then used as member names inside the archive.
            tar_in_memory: read the whole archive into memory before opening it.
            process_ahead: preprocess every item eagerly in the constructor.
        """
        self.root = root
        self.sample_rate = sample_rate
        self.default_speaker = default_speaker
        self.default_lang = default_lang
        self.logger = logging.getLogger(__name__)
        self.normalizer = {}
        self.tokenizer = tokenizer
        self.vocos = vocos_model
        self.vocos_device = (
            None if self.vocos is None else next(self.vocos.parameters()).device
        )
        self.device = device or self.vocos_device

        # Example archive handling: build with `tar -cvf ../Xz.tar *`,
        # extract with `tar -xf Xz.tar -C ./Xz`.
        self.tar_path = tar_path
        self.tar_file = None
        self.tar_io = None
        if tar_path is not None:
            if tar_in_memory:
                with open(tar_path, "rb") as f:
                    self.tar_io = io.BytesIO(f.read())
                self.tar_file = tarfile.open(fileobj=self.tar_io)
            else:
                self.tar_file = tarfile.open(tar_path)

        self.lazy_data, self.speakers = self.get_lazy_data(root, speakers)

        self.text_input_ids: dict[int, torch.Tensor] = {}
        self.audio_mel_specs: dict[int, torch.Tensor] = {}
        if process_ahead:
            for n, item in enumerate(self.lazy_data):
                self.audio_mel_specs[n] = self.preprocess_audio(item["filepath"])
                self.text_input_ids[n] = self.preprocess_text(
                    item["text"], item["lang"]
                )
            # Every item is cached, so the archive will not be read again.
            self._close_tar()

    @abc.abstractmethod
    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """Parse ``root`` into raw item dicts.

        Each dict must provide ``filepath`` and ``text``. ``speaker`` and
        ``lang`` are optional and fall back to the dataset defaults.
        """

    @staticmethod
    @abc.abstractmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        """Write ``lazy_data`` to ``save_path`` with filepaths relative to ``rel_path``."""

    def _close_tar(self) -> None:
        """Close the tar archive and its in-memory buffer, if they were opened."""
        if self.tar_file is not None:
            self.tar_file.close()
        if self.tar_io is not None:
            self.tar_io.close()

    def __len__(self) -> int:
        return len(self.lazy_data)

    def __getitem__(self, n: int) -> DataType:
        """Return item ``n`` with its token ids, mel spectrogram and attention masks.

        Processed tensors are cached. Once every item is cached, the tar
        archive (if any) is closed.
        """
        # Normalise negative indices so that e.g. -1 and len-1 share one cache
        # entry; otherwise the "all cached" check below could fire too early and
        # close the archive while some items are still unread.
        n = range(len(self.lazy_data))[n]
        lazy_data = self.lazy_data[n]
        if n in self.audio_mel_specs:
            audio_mel_specs = self.audio_mel_specs[n]
            text_input_ids = self.text_input_ids[n]
        else:
            audio_mel_specs = self.preprocess_audio(lazy_data["filepath"])
            text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"])
            self.audio_mel_specs[n] = audio_mel_specs
            self.text_input_ids[n] = text_input_ids
            if len(self.audio_mel_specs) == len(self.lazy_data):
                self._close_tar()
        text_attention_mask = torch.ones(
            len(text_input_ids), device=text_input_ids.device
        )
        # One audio position per two mel frames, rounded up.
        audio_attention_mask = torch.ones(
            (len(audio_mel_specs) + 1) // 2,
            device=audio_mel_specs.device,
        )
        return {
            "filepath": lazy_data["filepath"],
            "speaker": lazy_data["speaker"],
            "lang": lazy_data["lang"],
            "text": lazy_data["text"],
            "text_input_ids": text_input_ids,
            "text_attention_mask": text_attention_mask,
            "audio_mel_specs": audio_mel_specs,
            "audio_attention_mask": audio_attention_mask,
        }

    def get_lazy_data(
        self,
        root: str | io.BytesIO,
        speakers: typing.Iterable[str] | None = None,
    ) -> tuple[list[LazyDataType], set[str]]:
        """Build the item list from ``get_raw_data(root)``.

        Missing ``speaker``/``lang`` keys are filled with the dataset defaults
        and ``lang`` is lower-cased. If ``speakers`` is given, only those
        speakers are kept. Otherwise every speaker seen is collected.
        Filepaths are joined with the directory of ``root`` unless audio comes
        from a tar archive or ``root`` is not a path.

        Returns:
            ``(lazy_data, speakers)``: the items and the set of speakers.
        """
        # Materialise once: `speakers` may be a one-shot iterator (e.g. a
        # generator), so it must not be iterated again for membership tests.
        new_speakers = set(speakers) if speakers is not None else set()
        lazy_data = []
        raw_data = self.get_raw_data(root)
        folder_path = os.path.dirname(root) if isinstance(root, str) else ""
        for item in raw_data:
            if "speaker" not in item:
                item["speaker"] = self.default_speaker
            if "lang" not in item:
                item["lang"] = self.default_lang
            if speakers is not None:
                if item["speaker"] not in new_speakers:
                    continue
            else:
                new_speakers.add(item["speaker"])
            if self.tar_file is None and isinstance(root, str):
                filepath = os.path.join(folder_path, item["filepath"])
            else:
                filepath = item["filepath"]
            lazy_data.append(
                {
                    "filepath": filepath,
                    "speaker": item["speaker"],
                    "lang": item["lang"].lower(),
                    "text": item["text"],
                }
            )
        return lazy_data, new_speakers

    def preprocess_text(
        self,
        text: str,
        lang: str,
    ) -> torch.Tensor:
        """Tokenize ``text`` into a 1-D tensor of ids on ``self.device``.

        Invalid characters are first mapped with ``apply_character_map``. The
        text is then wrapped in the ``[Stts][spk_emb]...[Ptts]`` prompt.
        ``lang`` is currently unused.
        """
        invalid_characters = count_invalid_characters(text)
        if len(invalid_characters):
            text = apply_character_map(text)
        # Alternative prompt without a speaker embedding: f"[Stts][empty_spk]{text}[Ptts]"
        text = f"[Stts][spk_emb]{text}[Ptts]"
        text_token = self.tokenizer(
            text, return_tensors="pt", add_special_tokens=False
        ).to(device=self.device)
        # Drops the leading batch dim; assumes the tokenizer returns (1, text_len).
        return text_token["input_ids"].squeeze(0)

    def preprocess_audio(self, filepath: str) -> torch.Tensor:
        """Load ``filepath``, resample to ``self.sample_rate`` and compute mel features.

        Raises:
            FileNotFoundError: if ``filepath`` is not a regular file in the tar archive.
        """
        if self.tar_file is not None:
            file = self.tar_file.extractfile(filepath)
            if file is None:
                raise FileNotFoundError(
                    f"{filepath!r} is not a regular file in {self.tar_path}"
                )
            with file:
                waveform, sample_rate = torchaudio.load(file)
        else:
            waveform, sample_rate = torchaudio.load(filepath)
        waveform = waveform.to(device=self.vocos_device)
        if sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform,
                orig_freq=sample_rate,
                new_freq=self.sample_rate,
            )
        mel_spec: torch.Tensor = self.vocos.feature_extractor(waveform)
        # NOTE(review): squeeze(0) assumes single-channel audio — TODO confirm.
        return (
            mel_spec.to(device=self.device).squeeze(0).transpose(0, 1)
        )  # (audio_len*2, 100)
class JsonFolder(AudioFolder):
    """Dataset described by a JSON file holding a list of items such as::

        {"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}

    ``filepath`` is relative to the directory of the JSON file. ``speaker`` and
    ``lang`` may be omitted, in which case the dataset defaults are used.
    ``root`` may also be an already-open file object (text or UTF-8 bytes),
    e.g. a member extracted from a tar archive.
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """Load the JSON item list from a path or a readable file object."""
        if hasattr(root, "read"):
            # json.load accepts both text and UTF-8 encoded binary streams.
            return json.load(root)
        with open(root, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        """Write ``lazy_data`` as indented UTF-8 JSON to ``save_path``.

        Each ``filepath`` is stored relative to ``rel_path``. The items are
        copied, so ``lazy_data`` itself is not modified.
        """
        save_data = [item.copy() for item in lazy_data]
        for item in save_data:
            item["filepath"] = os.path.relpath(item["filepath"], rel_path)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(save_data, f, ensure_ascii=False, indent=4)
class ListFolder(AudioFolder):
    """Dataset described by a pipe-separated list file.

    Each non-empty line is ``filepath|speaker|lang|text``, for example
    ``path/to/file.wav|John|ZH|Hello``. Only the first three ``|`` split the
    line, so the text may itself contain ``|``. ``filepath`` is relative to
    the directory of the list file. ``root`` may also be an already-open file
    object (text or UTF-8 bytes), e.g. a member extracted from a tar archive.
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        """Parse every non-empty line of the list into an item dict.

        Raises:
            ValueError: if a line has fewer than four ``|``-separated fields.
        """
        if hasattr(root, "read"):
            content = root.read()
            if isinstance(content, bytes):
                content = content.decode("utf-8")
            # newline=None gives the same universal-newline handling as open().
            lines = io.StringIO(content, newline=None).readlines()
        else:
            with open(root, "r", encoding="utf-8") as f:
                lines = f.readlines()

        raw_data = []
        for line in lines:
            line = line.strip()
            if not line:
                continue
            filepath, speaker, lang, text = line.split(sep="|", maxsplit=3)
            raw_data.append(
                {
                    "text": text,
                    "filepath": filepath,
                    "speaker": speaker,
                    "lang": lang,
                }
            )
        return raw_data

    @staticmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        """Write ``lazy_data`` as a ``filepath|speaker|lang|text`` list file.

        Each ``filepath`` is stored relative to ``rel_path``. The items are
        copied, so ``lazy_data`` itself is not modified.
        """
        save_data = [item.copy() for item in lazy_data]
        for item in save_data:
            item["filepath"] = os.path.relpath(item["filepath"], rel_path)
        with open(save_path, "w", encoding="utf-8") as f:
            for item in save_data:
                f.write(
                    f"{item['filepath']}|{item['speaker']}|{item['lang']}|{item['text']}\n"
                )
class XzListTar(ListFolder):
    """``ListFolder`` whose audio files are read from a tar archive.

    ``root`` is either a ``.list`` file, an in-memory buffer (which requires
    ``tar_path``), or a folder, in which case ``<folder>/all.list`` is used.
    If that list file does not exist yet, ``all.list`` and ``all.json`` are
    generated in its folder from every ``.list``/``.json`` member of the
    archive. The generated file is always named ``all.list``.

    Extra keyword argument ``langs`` (default ``("zh", "en")``; ``None`` keeps
    every language) only filters that generation step.
    """

    def __init__(
        self,
        *args,
        root: str | io.BytesIO,
        tar_path: str | None = None,
        **kwargs,
    ):
        # `langs` only configures list generation. AudioFolder.__init__ does
        # not accept it, so it must not be forwarded.
        langs = kwargs.pop("langs", ("zh", "en"))
        if isinstance(root, io.BytesIO):
            assert tar_path is not None
        else:
            # make sure root is a list file
            if not root.endswith(".list"):  # folder case
                if os.path.isfile(root):
                    raise FileExistsError(f"{root} is a file!")
                elif not os.path.exists(root):
                    os.makedirs(root)
                root = os.path.join(root, "all.list")
        if isinstance(root, str) and not os.path.isfile(root):
            # Prepare all.list. AudioFolder.__init__ (which normally opens the
            # archive) has not run yet, so open it here just for this step.
            if tar_path is None:
                raise ValueError(
                    f"{root} does not exist and no tar_path was given to build it from"
                )
            with tarfile.open(tar_path) as tar_file:
                self.tar_file = tar_file
                self.concat_dataset(
                    save_folder=os.path.dirname(root) or ".",
                    langs=langs,
                )
        super().__init__(root, *args, tar_path=tar_path, **kwargs)

    def concat_dataset(
        self,
        save_folder: str | None = None,
        langs: typing.Sequence[str] | None = ("zh", "en"),
    ) -> None:
        """Merge every ``.list``/``.json`` member of ``self.tar_file`` into one dataset.

        Items whose ``lang`` is not in ``langs`` are dropped; ``None`` keeps all.
        The result is written to ``all.list`` and ``all.json`` in
        ``save_folder`` (default: the directory of ``self.root``).

        Raises:
            FileExistsError: if ``save_folder`` exists as a regular file.
        """
        if save_folder is None:
            save_folder = os.path.dirname(self.root)
        if os.path.isfile(save_folder):
            raise FileExistsError(f"{save_folder} already exists as a file!")
        elif not os.path.exists(save_folder):
            os.makedirs(save_folder)
        lazy_data = []
        for member in self.tar_file.getmembers():
            if not member.isfile():
                continue
            if member.name.endswith(".list"):
                folder_cls = ListFolder
            elif member.name.endswith(".json"):
                folder_cls = JsonFolder
            else:
                continue
            print(member.name)
            # The member is fully parsed inside the constructor, so it can be
            # closed right after.
            with self.tar_file.extractfile(member) as root_io:
                lazy_data += folder_cls(root_io).lazy_data
        if langs is not None:
            lazy_data = [item for item in lazy_data if item["lang"] in langs]
        ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data)
        JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data)
        print(f"all.list and all.json are saved to {save_folder}")
class XzListFolder(ListFolder): | |
""" | |
[XzδΉεΈ](https://space.bilibili.com/5859321) | |
Only look at the basename of filepath in list file. Previous folder paths are ignored. | |
Files are organized as `[list basename]/[file basename]` | |
Example tree structure: | |
[folder] | |
βββ speaker_A | |
β βββ 1.wav | |
β βββ 2.wav | |
βββ speaker_A.list | |
βββ speaker_B | |
β βββ 1.wav | |
β βββ 2.wav | |
βββ speaker_B.list | |
""" | |
def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: | |
raw_data = super().get_raw_data(root) | |
for item in raw_data: | |
item["filepath"] = os.path.join( | |
os.path.basename(root).removesuffix(".list"), | |
os.path.basename(item["filepath"]), | |
) | |
return raw_data | |
class AudioCollator:
    """Pad and stack per-sample dicts into one batch dict.

    ``None`` samples are dropped. Text ids and text masks are left-padded to
    the longest text (ids with ``text_pad``, masks with 0). Audio masks are
    right-padded with 0 to the longest audio. Mel spectrograms are
    right-padded along their first (frame) dimension with ``audio_pad`` to
    twice that length.
    """

    def __init__(self, text_pad: int = 0, audio_pad: int = 0):
        self.text_pad = text_pad
        self.audio_pad = audio_pad

    def __call__(self, batch: list[DataType]):
        samples = [sample for sample in batch if sample is not None]
        audio_maxlen = max(len(s["audio_attention_mask"]) for s in samples)
        text_maxlen = max(len(s["text_attention_mask"]) for s in samples)
        pad = torch.nn.functional.pad

        def pad_left(tensor: torch.Tensor, value) -> torch.Tensor:
            # (before, after) on the last dimension: padding goes in front.
            return pad(tensor, (text_maxlen - len(tensor), 0), value=value)

        def pad_mel(mel: torch.Tensor) -> torch.Tensor:
            # (0, 0) leaves the feature dimension alone; frames are padded at the end.
            return pad(
                mel, (0, 0, 0, 2 * audio_maxlen - len(mel)), value=self.audio_pad
            )

        def pad_audio_mask(mask: torch.Tensor) -> torch.Tensor:
            return pad(mask, (0, audio_maxlen - len(mask)), value=0)

        return {
            "filepath": [s["filepath"] for s in samples],
            "speaker": [s["speaker"] for s in samples],
            "lang": [s["lang"] for s in samples],
            "text": [s["text"] for s in samples],
            "text_input_ids": torch.stack(
                [pad_left(s["text_input_ids"], self.text_pad) for s in samples]
            ),
            "text_attention_mask": torch.stack(
                [pad_left(s["text_attention_mask"], 0) for s in samples]
            ),
            "audio_mel_specs": torch.stack(
                [pad_mel(s["audio_mel_specs"]) for s in samples]
            ),
            "audio_attention_mask": torch.stack(
                [pad_audio_mask(s["audio_attention_mask"]) for s in samples]
            ),
        }
def formalize_xz_list(src_folder: str):
    """Rewrite every ``.list`` file under ``src_folder`` in canonical list form.

    Each list is parsed with ``XzListFolder``, which keeps only file basenames
    and places the audio under ``[list basename]/``. It is then saved back in
    place, with each ``filepath`` relative to the list file's own directory.
    That is the directory ``ListFolder`` resolves paths against when the list
    is read again.
    """
    for root, _, files in os.walk(src_folder):
        for file in files:
            if not file.endswith(".list"):
                continue
            filepath = os.path.join(root, file)
            print(filepath)
            lazy_data = XzListFolder(filepath).lazy_data
            # Relative to `root` (the list's folder), not `src_folder`.
            # Otherwise lists in sub-folders would point at wrong paths when
            # re-read by ListFolder.
            XzListFolder.save_config(filepath, lazy_data, rel_path=root)
def concat_dataset(
    src_folder: str, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
) -> None:
    """Merge every ``.list`` / ``.json`` dataset found under ``src_folder``.

    Items whose ``lang`` is not in ``langs`` are dropped; ``None`` keeps all.
    The merged data is written to ``all.list`` and ``all.json`` in
    ``save_folder`` (default: ``src_folder``), with file paths relative to
    ``save_folder``. When both folders are the same, files named ``all.list``
    or ``all.json`` are skipped so earlier outputs are not merged again.

    Raises:
        FileExistsError: if ``save_folder`` exists as a regular file.
    """
    save_folder = src_folder if save_folder is None else save_folder
    if os.path.isfile(save_folder):
        raise FileExistsError(f"{save_folder} already exists as a file!")
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    skip_outputs = os.path.samefile(src_folder, save_folder)
    readers = ((".list", ListFolder), (".json", JsonFolder))
    merged = []
    for dirpath, _, filenames in os.walk(src_folder):
        for name in filenames:
            path = os.path.join(dirpath, name)
            if skip_outputs and name in ("all.list", "all.json"):
                continue
            for suffix, reader in readers:
                if name.endswith(suffix):
                    print(path)
                    merged += reader(path).lazy_data

    if langs is not None:
        merged = [entry for entry in merged if entry["lang"] in langs]
    for out_name, writer in (("all.list", ListFolder), ("all.json", JsonFolder)):
        writer.save_config(
            os.path.join(save_folder, out_name), merged, rel_path=save_folder
        )
    print(f"all.list and all.json are saved to {save_folder}")