# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import collections
import io
import itertools
import json
import logging
import math
import mmap
import multiprocessing as mp
import os
import pickle
import random
import shutil
import struct
import sys
import tarfile
import wave
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from io import BytesIO
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from wids import wids

torchaudio.set_audio_backend('soundfile')
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}

# Mean speaker embeddings used as defaults; fall back to zero vectors when the
# checkpoint files are not available (e.g. outside the training environment).
try:
    MAIN_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt")
    GPT_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt")
except Exception:
    MAIN_SPK_EMBEDDING = torch.zeros(1, 192)
    GPT_SPK_EMBEDDING = torch.zeros(1, 192)


def parquet_opener(data, mode='train', tts_data={}):
    """ Given a parquet url or local file, yield one sample per row.

        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list

        Returns:
            Iterable[{src, ...row columns}]
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            df = pq.read_table(url).to_pandas()
            for i in range(len(df)):
                if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
                    continue
                sample.update(dict(df.loc[i]))
                if mode == 'train':
                    # NOTE do not return sample directly, must initialize a new dict
                    yield {**sample}
                else:
                    for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
                        yield {**sample, 'tts_index': index, 'tts_text': text}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))
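

# Illustrative usage only: in inference mode, `tts_data` is assumed to map an
# utterance id to the list of texts to synthesize for it, e.g.
#   tts_data = {'utt_0001': ['Hello there.', 'Second sentence.']}
#   for s in parquet_opener([{'src': 'shard-00000.parquet'}],
#                           mode='inference', tts_data=tts_data):
#       print(s['utt'], s['tts_index'], s['tts_text'])
# The shard path and utterance id here are hypothetical.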


TarHeader = collections.namedtuple(
    "TarHeader",
    [
        "name",
        "mode",
        "uid",
        "gid",
        "size",
        "mtime",
        "chksum",
        "typeflag",
        "linkname",
        "magic",
        "version",
        "uname",
        "gname",
        "devmajor",
        "devminor",
        "prefix",
    ],
)


def parse_tar_header(header_bytes):
    # A ustar header block is 512 bytes; the 16 fields below cover the first
    # 500, which is all MMTar needs (name and size in particular).
    header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
    return TarHeader(*header)


class MMTar:
    def __init__(self, file_path: Path | str):
        self.stream = open(file_path, "rb")
        self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)

    def __del__(self):
        try:
            self.mmap.close()
            self.stream.close()
        except Exception:  # noqa
            pass

    def get_at_offset(self, offset) -> tuple[str, bytes]:
        # Read the member name and size from the tar header, then slice the
        # payload directly out of the mmap without copying the whole archive.
        header = parse_tar_header(self.mmap[offset: offset + 500])
        name = header.name.decode("utf-8").strip("\x00")
        start = offset + 512
        end = start + int(header.size.decode("utf-8")[:-1], 8)
        return name, self.mmap[start:end]


class Tar:
    def __init__(self, path: Path):
        self.tar = MMTar(path)
        # The pickled sidecar index maps member names to byte offsets inside
        # the tar, so lookups avoid scanning the archive.
        indices_path = path.with_suffix(".index")
        self.index = pickle.loads(indices_path.read_bytes())
        self.name_mapping = {}
        for name, offset, _ in self.index:
            self.name_mapping[name] = offset

    def read(self, name: str) -> bytes:
        return self.tar.get_at_offset(self.name_mapping[name])[1]
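

# A minimal sketch of how the ".index" sidecar consumed by `Tar` could be
# generated, assuming the index is a pickled list of (name, offset, size)
# triples where `offset` points at the member's 512-byte ustar header.
# `build_tar_index` is an illustrative helper, not part of the original module.
def build_tar_index(tar_path: Path) -> None:
    entries = []
    with tarfile.open(tar_path, "r") as tf:
        for member in tf:
            if member.isfile():
                # TarInfo.offset is the header offset that MMTar.get_at_offset expects.
                entries.append((member.name, member.offset, member.size))
    tar_path.with_suffix(".index").write_bytes(pickle.dumps(entries))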


def _cosy_jsonl_opener_for(jsonl_suffix):
    """ Build an opener that pairs a `<name><jsonl_suffix>` token file with
        its `<name>.tar` audio archive and yields one sample per jsonl line.

        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list

        Returns:
            Iterable[{src, speech_token, speech, sample_rate}]
    """
    def opener(data, mode='train', tts_data={}):
        for sample in data:
            assert 'src' in sample
            cosy_jsonl_path = sample['src']
            tar_file_path = cosy_jsonl_path.replace(jsonl_suffix, ".tar")
            try:
                tar_data = Tar(Path(tar_file_path))
                with open(cosy_jsonl_path, 'r') as f:
                    for line in f:
                        item = json.loads(line)
                        sample['speech_token'] = torch.tensor(item['cosy_token'])
                        sample['speech'], sample['sample_rate'] = torchaudio.load(
                            io.BytesIO(tar_data.read(item['filename'])))
                        yield {**sample}
            except Exception as ex:
                logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
    return opener


# Openers for the different VQ token variants; they differ only in the jsonl
# suffix that is stripped to locate the companion tar archive.
cosy_jsonl_opener = _cosy_jsonl_opener_for(".vq0907.jsonl")
cosy_jsonl_opener_vq0918_nopool = _cosy_jsonl_opener_for(".vq0918-nopool.jsonl")
cosy_jsonl_opener_vq0918_pool2 = _cosy_jsonl_opener_for(".vq0918-pool2.jsonl")
cosy_jsonl_opener_vq0918_pool4 = _cosy_jsonl_opener_for(".vq0918-pool4.jsonl")
cosy_jsonl_opener_vq0918_pool8 = _cosy_jsonl_opener_for(".vq0918-pool8.jsonl")


def process_sft_vq0918_pool4(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool4.npy", "")
        try:
            sample['speech_token'] = torch.tensor(np.load(token_npy_path))
            sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
            # Downmix multi-channel audio to mono.
            if sample['speech'].shape[0] > 1:
                sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))


def process_sft_vq0918_pool4_split(data, mode='train', split_token=25, tts_data={}):
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool4.npy", "")
        try:
            speech_token = torch.tensor(np.load(token_npy_path))
            speech, sample_rate = torchaudio.load(wav_path)
            if speech.shape[0] > 1:
                speech = speech.mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['sample_rate'] = sample_rate
            # pool4 tokens run at 12.5 Hz, so a prefix of N tokens covers
            # ceil(N / 12.5 * sample_rate) waveform samples.
            num_splits = (speech_token.size(0) + split_token - 1) // split_token
            for split_id in range(num_splits):
                end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
                end_speech_idx = int(np.ceil(end_token_idx / 12.5 * sample_rate))
                sample['speech_token'] = speech_token[:end_token_idx]
                sample['speech'] = speech[:, :end_speech_idx]
                yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
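

# Worked example for the split arithmetic above (illustrative numbers): with
# split_token=25 and 16 kHz audio, each chunk adds 25 / 12.5 = 2 s of speech,
# so the first yield covers tokens [0, 25) and waveform samples [0, 32000).
# process_sft_vq0918_pool2_split below is the 25 Hz (pool2) analogue and
# divides by 25 instead of 12.5.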


def process_sft_vq0918_pool2(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src'].replace(".vq0918-pool4.npy", ".vq0918-pool2.npy")
        wav_path = token_npy_path.replace(".vq0918-pool2.npy", "")
        try:
            sample['speech_token'] = torch.tensor(np.load(token_npy_path))
            sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
            if sample['speech'].shape[0] > 1:
                sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))


def process_sft_vq0918_pool2_split(data, mode='train', split_token=50, tts_data={}):
    for sample in data:
        assert 'src' in sample
        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool2.npy", "")
        try:
            speech_token = torch.tensor(np.load(token_npy_path))
            speech, sample_rate = torchaudio.load(wav_path)
            if speech.shape[0] > 1:
                speech = speech.mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['sample_rate'] = sample_rate
            # pool2 tokens run at 25 Hz, so a prefix of N tokens covers
            # ceil(N / 25 * sample_rate) waveform samples.
            num_splits = (speech_token.size(0) + split_token - 1) // split_token
            for split_id in range(num_splits):
                end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
                end_speech_idx = int(np.ceil(end_token_idx / 25 * sample_rate))
                sample['speech_token'] = speech_token[:end_token_idx]
                sample['speech'] = speech[:, :end_speech_idx]
                yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))


def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        wav_path = None
        try:
            entry = json.loads(sample['src'])
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            for conv in entry["conversations"]:
                if "response_wav" in conv:
                    wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))


def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}):
    for sample in data:
        assert 'src' in sample
        wav_path = None
        try:
            entry = json.loads(sample['src'])
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            for conv in entry["conversations"]:
                if "response_wav" in conv:
                    wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    yield {**sample}
                if "prompt_wav" in conv:
                    wav_path = f"/workspace/audio_data/sft/{conv['prompt_wav']}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    yield {**sample}
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Filter samples according to feature and label length.

        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (10ms frames)
            min_length: drop utterances shorter than min_length (10ms frames)
            token_max_length: drop utterances with more than token_max_length
                tokens, especially when using char units for English modeling
            token_min_length: drop utterances with fewer than
                token_min_length tokens
            min_output_input_ratio: minimum ratio of
                token_length / feats_length (10ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        # sample['speech'] is a torch.Tensor; there are 100 frames per second.
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['text_token']) < token_min_length:
            continue
        if len(sample['text_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['text_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['text_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample


def filter_speech_token(data,
                        max_length=10240,
                        min_length=10,
                        token_max_length=5000,
                        token_min_length=1,
                        min_output_input_ratio=0.0005,
                        max_output_input_ratio=30,
                        mode='train'):
    """ Filter samples according to feature and speech-token length.

        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (10ms frames)
            min_length: drop utterances shorter than min_length (10ms frames)
            token_max_length: drop utterances with more than token_max_length
                speech tokens
            token_min_length: drop utterances with fewer than
                token_min_length speech tokens
            min_output_input_ratio: minimum ratio of
                token_length / feats_length (10ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (10ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        # sample['speech'] is a torch.Tensor; there are 100 frames per second.
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['speech_token']) < token_min_length:
            continue
        if len(sample['speech_token']) > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            if len(sample['speech_token']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['speech_token']) / num_frames > max_output_input_ratio:
                continue
        yield sample


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Resample data.

        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate
            min_sample_rate: drop utterances whose original rate is lower

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['speech']
        if sample_rate != resample_rate:
            if sample_rate < min_sample_rate:
                continue
            sample['sample_rate'] = resample_rate
            sample['speech'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        # Renormalize to [-1, 1] if resampling overshoots.
        max_val = sample['speech'].abs().max()
        if max_val > 1:
            sample['speech'] /= max_val
        yield sample


def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Extract fbank features.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        waveform = sample['speech']
        # (1, num_mels, T) -> (T, num_mels)
        mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
        sample['speech_feat'] = mat
        del sample['speech']
        yield sample


def parse_embedding(data, normalize, mode='train'):
    """ Parse utt_embedding/spk_embedding.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
        if normalize:
            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Decode text to chars or BPE.

        Inplace operation.

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    tokenizer = get_tokenizer()
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
        if mode == 'inference':
            sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Locally shuffle the data.

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The samples left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500, mode='train'):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`.

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['speech_feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The samples left over
    buf.sort(key=lambda x: x['speech_feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Statically batch the data by `batch_size`.

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamically batch the data until the total frames in the batch
        reach `max_frames_in_batch`.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        # Batches are padded to the longest utterance, so a batch costs
        # longest_frames * batch_size frames.
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch and len(buf) > 0:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch """
    if mode == 'inference':
        return static_batch(data, 1)
    else:
        if batch_type == 'static':
            return static_batch(data, batch_size)
        elif batch_type == 'dynamic':
            return dynamic_batch(data, max_frames_in_batch)
        else:
            logging.fatal('Unsupported batch type {}'.format(batch_type))
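

# A minimal sketch of how these generator stages compose into a data pipeline;
# the shard path, tokenizer factory, and feature extractor below are
# hypothetical stand-ins, not part of this module.
#
# pipeline = parquet_opener([{'src': 'shard-00000.parquet'}])
# pipeline = tokenize(pipeline, get_tokenizer=my_tokenizer_factory, allowed_special='all')
# pipeline = filter(pipeline)
# pipeline = resample(pipeline, resample_rate=22050)
# pipeline = compute_fbank(pipeline, feat_extractor=my_feat_extractor)
# pipeline = shuffle(pipeline, shuffle_size=10000)
# pipeline = sort(pipeline, sort_size=500)
# pipeline = batch(pipeline, batch_type='dynamic', max_frames_in_batch=12000)
# pipeline = padding(pipeline, use_spk_embedding=False)
# for minibatch in pipeline:
#     ...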


def padding(data, use_spk_embedding, mode='train'):
    """ Pad a batch of samples into training tensors.

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)

        utts = [sample[i]['utt'] for i in order]
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch


def padding_speech_token(data, use_spk_embedding, mode='train'):
    """ Pad a batch of samples into training tensors (speech tokens only).

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)
        try:
            speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
            speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
            speech_token = pad_sequence(speech_token,
                                        batch_first=True,
                                        padding_value=0)
            speech_feat = [sample[i]['speech_feat'] for i in order]
            speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
            speech_feat = pad_sequence(speech_feat,
                                       batch_first=True,
                                       padding_value=0)
            batch = {
                "speech_token": speech_token,
                "speech_token_len": speech_token_len,
                "speech_feat": speech_feat,
                "speech_feat_len": speech_feat_len,
            }
            if mode == 'inference':
                tts_text = [sample[i]['tts_text'] for i in order]
                tts_index = [sample[i]['tts_index'] for i in order]
                tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
                tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
                tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
                batch.update({'tts_text': tts_text,
                              'tts_index': tts_index,
                              'tts_text_token': tts_text_token,
                              'tts_text_token_len': tts_text_token_len})
            # No speaker embedding is available here; feed a zero vector.
            batch["embedding"] = torch.zeros((batch["speech_feat"].size(0), 192),
                                             device=batch["speech_feat"].device)
            yield batch
        except Exception as ex:
            logging.warning('ex info {}'.format(ex))


def padding_speech_token_spk(data, use_spk_embedding, mode='train'):
    """ Pad a batch of samples into training tensors (speech tokens plus
        speaker embedding).

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
                                       dtype=torch.int32)
        order = torch.argsort(speech_feat_len, descending=True)
        try:
            speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
            speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
            speech_token = pad_sequence(speech_token,
                                        batch_first=True,
                                        padding_value=0)
            speech_feat = [sample[i]['speech_feat'] for i in order]
            speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
            speech_feat = pad_sequence(speech_feat,
                                       batch_first=True,
                                       padding_value=0)
            spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
            batch = {
                "speech_token": speech_token,
                "speech_token_len": speech_token_len,
                "speech_feat": speech_feat,
                "speech_feat_len": speech_feat_len,
                "spk_embedding": spk_embedding,
            }
            if mode == 'inference':
                tts_text = [sample[i]['tts_text'] for i in order]
                tts_index = [sample[i]['tts_index'] for i in order]
                tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
                tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
                tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
                batch.update({'tts_text': tts_text,
                              'tts_index': tts_index,
                              'tts_text_token': tts_text_token,
                              'tts_text_token_len': tts_text_token_len})
            batch["embedding"] = batch["spk_embedding"]
            yield batch
        except Exception as ex:
            logging.warning('ex info {}'.format(ex))