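# NOTE: the following lines are page residue from the Hugging Face file viewer,
# kept as comments so that the module remains valid Python:
#   sam2ai's picture
#   Synced repo using 'sync_with_huggingface' Github Action
#   6de3e11
#   raw
#   history blame
#   12.1 kB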
import json
import os
import random
import sys
from typing import List
import librosa
import numpy as np
import soundfile
from torch.utils.data import Dataset
from tqdm import tqdm
from utils.binary import DatasetReader
class CustomDataset(Dataset):
    """Dataset that turns a data list into processor features and token labels.

    Each entry is read either from a JSON-lines file or, when the list path
    ends with ``.header``, through ``DatasetReader``. An entry is expected to
    provide ``audio.path`` (optionally ``audio.start_time`` /
    ``audio.end_time``), ``duration``, ``sentence`` (or ``sentences`` when
    ``timestamps`` is enabled) and optionally ``language``.
    """

    def __init__(self,
                 data_list_path,
                 processor,
                 mono=True,
                 language=None,
                 timestamps=False,
                 sample_rate=16000,
                 min_duration=0.5,
                 max_duration=30,
                 augment_config_path=None):
        """
        Args:
            data_list_path: Path of the data list. A path ending in ``.header``
                is opened with ``DatasetReader``; anything else is read as a
                JSON-lines text file.
            processor: Whisper processor. It must expose ``tokenizer`` and be
                callable with ``audio=``, ``sampling_rate=`` and ``text=``.
            mono: If True, audio is down-mixed to a single channel.
            language: Fallback language used when an entry has no
                ``language`` field.
            timestamps: If True, labels are built from the per-segment
                ``sentences`` list and include timestamp tokens.
            sample_rate: Sample rate the audio is resampled to (default 16000).
            min_duration: Minimum entry duration in seconds; must be >= 0.5.
            max_duration: Maximum entry duration in seconds; must be <= 30.
                ``-1`` disables the upper bound when filtering a JSON-lines list.
            augment_config_path: Optional path to a JSON file holding a list
                of augmentation configs (see ``augment``).
        """
        super().__init__()
        # NOTE: validation is kept as `assert` so the exception type callers
        # may rely on (AssertionError) does not change.
        assert min_duration >= 0.5, f"min_duration 0.5:{min_duration}"
        assert max_duration <= 30, f"max_duration 30:{max_duration}"
        self.data_list_path = data_list_path
        self.processor = processor
        self.sample_rate = sample_rate
        self.mono = mono
        self.language = language
        self.timestamps = timestamps
        self.min_duration = min_duration
        self.max_duration = max_duration
        # Cache the special-token ids needed to build labels by hand.
        self.vocab = self.processor.tokenizer.get_vocab()
        # Timestamp ids are taken to start right after <|notimestamps|>.
        self.timestamp_begin = self.vocab['<|notimestamps|>'] + 1
        self.startoftranscript = self.vocab['<|startoftranscript|>']
        self.endoftext = self.vocab['<|endoftext|>']
        self.nocaptions = self.vocab['<|nocaptions|>']
        self.data_list: List[dict] = []
        # Load (and duration-filter) the data list.
        self._load_data_list()
        # Augmentation state; the speed table and noise file list are built
        # lazily on first use inside `augment`.
        self.augment_configs = None
        self.noises_path = None
        self.speed_rates = None
        if augment_config_path:
            with open(augment_config_path, 'r', encoding='utf-8') as f:
                self.augment_configs = json.load(f)

    def _load_data_list(self):
        """Populate ``self.data_list`` from ``self.data_list_path``.

        ``.header`` paths delegate filtering to ``DatasetReader`` and store its
        keys. Otherwise every non-blank line is parsed as JSON and kept when
        it is a dict whose ``duration`` lies within the configured bounds.
        """
        if self.data_list_path.endswith(".header"):
            # Binary dataset: the reader receives the duration limits itself.
            self.dataset_reader = DatasetReader(data_header_path=self.data_list_path,
                                                min_duration=self.min_duration,
                                                max_duration=self.max_duration)
            self.data_list = self.dataset_reader.get_keys()
            return
        # Plain-text dataset: one JSON object per line.
        with open(self.data_list_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        self.data_list = []
        for line in tqdm(lines, desc=''):
            # A blank line would make json.loads raise; skip it instead.
            if not line.strip():
                continue
            entry = json.loads(line)
            if not isinstance(entry, dict):
                continue
            # Drop entries outside the allowed duration range.
            if entry["duration"] < self.min_duration:
                continue
            if self.max_duration != -1 and entry["duration"] > self.max_duration:
                continue
            self.data_list.append(dict(entry))

    def _get_list_data(self, idx):
        """Load, down-mix, augment and resample the audio of entry ``idx``.

        Returns:
            Tuple ``(sample, sample_rate, transcript, language)``. ``transcript``
            is ``entry["sentences"]`` when timestamps are enabled, otherwise
            ``entry["sentence"]``; ``language`` is None when the entry has no
            ``language`` key. ``sample_rate`` is the rate of the returned
            ``sample``.
        """
        if self.data_list_path.endswith(".header"):
            data_list = self.dataset_reader.get_data(self.data_list[idx])
        else:
            data_list = self.data_list[idx]
        # Split the entry into audio path, transcript and language.
        audio_file = data_list["audio"]['path']
        transcript = data_list["sentences"] if self.timestamps else data_list["sentence"]
        language = data_list["language"] if 'language' in data_list else None
        audio_info = data_list["audio"]
        if 'start_time' not in audio_info:
            sample, sample_rate = soundfile.read(audio_file, dtype='float32')
        else:
            # Only read the requested segment of the file.
            sample, sample_rate = self.slice_from_file(audio_file,
                                                       start=audio_info["start_time"],
                                                       end=audio_info["end_time"])
        # soundfile yields (frames, channels); librosa expects channels first.
        sample = sample.T
        if self.mono:
            sample = librosa.to_mono(sample)
        # Optional data augmentation (may change the sample rate).
        if self.augment_configs:
            sample, sample_rate = self.augment(sample, sample_rate)
        # Bring the audio to the target rate and report the rate it now has.
        if self.sample_rate != sample_rate:
            sample = self.resample(sample, orig_sr=sample_rate, target_sr=self.sample_rate)
            sample_rate = self.sample_rate
        return sample, sample_rate, transcript, language

    def _load_timestamps_transcript(self, transcript: List[dict]):
        """Build a label sequence with timestamp tokens around each segment.

        Args:
            transcript: List of segments, each with ``start``, ``end`` (seconds)
                and ``text``.

        Returns:
            Dict with a single ``labels`` key: the first three tokenizer prefix
            tokens, then ``start-id, text-ids..., end-id`` per segment, then the
            end-of-text id.
        """
        assert isinstance(transcript, list), f"transcript list:{type(transcript)}"
        data = dict()
        # Slicing copies, so the tokenizer's own prefix list is not mutated.
        labels = self.processor.tokenizer.prefix_tokens[:3]
        for t in transcript:
            # Snap times to an even number of hundredths (start rounds up, end
            # rounds down); the `// 2` then maps one timestamp id to each
            # 0.02 s step.
            start = t['start'] if round(t['start'] * 100) % 2 == 0 else t['start'] + 0.01
            start = self.timestamp_begin + round(start * 100) // 2
            end = t['end'] if round(t['end'] * 100) % 2 == 0 else t['end'] - 0.01
            end = self.timestamp_begin + round(end * 100) // 2
            # Drops the first four and the last id of the encoded text
            # (presumably the prefix tokens and end-of-text - TODO confirm).
            label = self.processor(text=t['text']).input_ids[4:-1]
            labels.extend([start])
            labels.extend(label)
            labels.extend([end])
        data['labels'] = labels + [self.endoftext]
        return data

    def __getitem__(self, idx):
        """Return processor output (features and labels) for entry ``idx``.

        On any exception the error is printed to stderr and a randomly chosen
        other entry is returned instead.
        """
        try:
            sample, sample_rate, transcript, language = self._get_list_data(idx=idx)
            # Per-entry language wins over the dataset-wide default.
            self.processor.tokenizer.set_prefix_tokens(language=language if language is not None else self.language)
            if len(transcript) > 0:
                if self.timestamps:
                    # Labels are assembled by hand; features come from the processor.
                    data = self._load_timestamps_transcript(transcript=transcript)
                    data["input_features"] = self.processor(audio=sample, sampling_rate=self.sample_rate).input_features
                else:
                    # The processor produces both features and labels.
                    data = self.processor(audio=sample, sampling_rate=self.sample_rate, text=transcript)
            else:
                # Empty transcript: label the audio as containing no speech.
                data = self.processor(audio=sample, sampling_rate=self.sample_rate)
                data['labels'] = [self.startoftranscript, self.nocaptions, self.endoftext]
            return data
        except Exception as e:
            print(f'idx:{idx} error - {e}', file=sys.stderr)
            # NOTE: retries recursively; if every entry fails this ends in a
            # RecursionError rather than looping forever.
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def __len__(self):
        """Return the number of entries in the data list."""
        return len(self.data_list)

    @staticmethod
    def slice_from_file(file, start, end):
        """Read the ``[start, end]`` segment (seconds) of an audio file.

        Negative times are counted from the end of the file. ``start`` is
        clamped to 0 and ``end`` to the file duration.

        Returns:
            Tuple ``(sample, sample_rate)`` with ``sample`` as float32.

        Raises:
            ValueError: if ``end`` is still negative after adjustment, or if
                ``start`` lies after ``end``.
        """
        # The context manager closes the file handle even when a bound check raises.
        with soundfile.SoundFile(file) as sndfile:
            sample_rate = sndfile.samplerate
            duration = round(float(len(sndfile)) / sample_rate, 3)
            start = round(start, 3)
            end = round(end, 3)
            # Negative offsets are relative to the end of the file.
            if start < 0.0:
                start += duration
            if end < 0.0:
                end += duration
            # Clamp to the valid range, then validate.
            if start < 0.0:
                start = 0.0
            if end > duration:
                end = duration
            if end < 0.0:
                raise ValueError("(%f s)" % end)
            if start > end:
                raise ValueError("(%f s)(%f s)" % (start, end))
            start_frame = int(start * sample_rate)
            end_frame = int(end * sample_rate)
            sndfile.seek(start_frame)
            sample = sndfile.read(frames=end_frame - start_frame, dtype='float32')
        return sample, sample_rate

    def augment(self, sample, sample_rate):
        """Apply the configured augmentations in order.

        Each config is a dict with ``type`` (``speed``, ``shift``, ``volume``,
        ``resample`` or ``noise``), ``prob`` (probability of applying it) and
        ``params``.

        Returns:
            Tuple ``(sample, sample_rate)``; the rate changes only when a
            ``resample`` augmentation was applied.
        """
        for config in self.augment_configs:
            aug_type = config['type']
            if aug_type == 'speed' and random.random() < config['prob']:
                if self.speed_rates is None:
                    params = config['params']
                    # The candidate rates are built once and then reused.
                    self.speed_rates = np.linspace(params['min_speed_rate'],
                                                   params['max_speed_rate'],
                                                   params['num_rates'],
                                                   endpoint=True)
                rate = random.choice(self.speed_rates)
                sample = self.change_speed(sample, speed_rate=rate)
            elif aug_type == 'shift' and random.random() < config['prob']:
                params = config['params']
                shift_ms = random.randint(params['min_shift_ms'], params['max_shift_ms'])
                sample = self.shift(sample, sample_rate, shift_ms=shift_ms)
            elif aug_type == 'volume' and random.random() < config['prob']:
                params = config['params']
                gain = random.randint(params['min_gain_dBFS'], params['max_gain_dBFS'])
                sample = self.volume(sample, gain=gain)
            elif aug_type == 'resample' and random.random() < config['prob']:
                new_sample_rate = np.random.choice(config['params']['new_sample_rates'])
                sample = self.resample(sample, orig_sr=sample_rate, target_sr=new_sample_rate)
                sample_rate = new_sample_rate
            elif aug_type == 'noise' and random.random() < config['prob']:
                params = config['params']
                min_snr_dB, max_snr_dB = params['min_snr_dB'], params['max_snr_dB']
                if self.noises_path is None:
                    # The noise directory is listed once and then reused.
                    self.noises_path = []
                    noise_dir = params['noise_dir']
                    if os.path.exists(noise_dir):
                        for file in os.listdir(noise_dir):
                            self.noises_path.append(os.path.join(noise_dir, file))
                # Without noise files there is nothing to mix in; skip rather
                # than let random.choice fail on an empty list.
                if not self.noises_path:
                    continue
                noise_path = random.choice(self.noises_path)
                snr_dB = random.randint(min_snr_dB, max_snr_dB)
                sample = self.add_noise(sample, sample_rate, noise_path=noise_path, snr_dB=snr_dB)
        return sample, sample_rate

    @staticmethod
    def change_speed(sample, speed_rate):
        """Change playback speed by linear interpolation.

        The output has ``int(len(sample) / speed_rate)`` float32 samples; a
        rate of exactly 1.0 returns the input unchanged.

        Raises:
            ValueError: if ``speed_rate`` is not positive.
        """
        if speed_rate == 1.0:
            return sample
        if speed_rate <= 0:
            raise ValueError("error")
        old_length = sample.shape[0]
        new_length = int(old_length / speed_rate)
        old_indices = np.arange(old_length)
        new_indices = np.linspace(start=0, stop=old_length, num=new_length)
        return np.interp(new_indices, old_indices, sample).astype(np.float32)

    @staticmethod
    def shift(sample, sample_rate, shift_ms):
        """Shift the audio in time, in place, filling the vacated part with zeros.

        A positive ``shift_ms`` moves the content towards the start (the tail
        is zeroed); a negative one moves it towards the end (the head is
        zeroed).

        Raises:
            ValueError: if the shift is longer than the audio itself.
        """
        duration = sample.shape[0] / sample_rate
        if abs(shift_ms) / 1000.0 > duration:
            raise ValueError("shift_ms")
        shift_samples = int(shift_ms * sample_rate / 1000)
        if shift_samples > 0:
            sample[:-shift_samples] = sample[shift_samples:]
            sample[-shift_samples:] = 0
        elif shift_samples < 0:
            sample[-shift_samples:] = sample[:shift_samples]
            sample[:-shift_samples] = 0
        return sample

    @staticmethod
    def volume(sample, gain):
        """Scale ``sample`` in place by ``gain`` decibels and return it."""
        sample *= 10. ** (gain / 20.)
        # Must return the array: `augment` assigns the result back to `sample`.
        return sample

    @staticmethod
    def resample(sample, orig_sr, target_sr):
        """Resample ``sample`` from ``orig_sr`` to ``target_sr`` with librosa."""
        return librosa.resample(sample, orig_sr=orig_sr, target_sr=target_sr)

    def add_noise(self, sample, sample_rate, noise_path, snr_dB, max_gain_db=300.0):
        """Mix a noise recording into ``sample`` at the given signal-to-noise ratio.

        Args:
            sample: Audio to augment; modified in place.
            sample_rate: Rate the noise file is loaded at.
            noise_path: Path of the noise audio file.
            snr_dB: Desired signal-to-noise ratio in dB.
            max_gain_db: Upper bound for any gain that is applied.

        Returns:
            The noisy ``sample``.
        """
        noise_sample, sr = librosa.load(noise_path, sr=sample_rate)
        # Normalise the speech level to -20 dB RMS first.
        target_db = -20
        gain = min(max_gain_db, target_db - self.rms_db(sample))
        sample *= 10. ** (gain / 20.)
        # Scale the noise so that speech minus noise level equals snr_dB.
        sample_rms_db, noise_rms_db = self.rms_db(sample), self.rms_db(noise_sample)
        noise_gain_db = min(sample_rms_db - noise_rms_db - snr_dB, max_gain_db)
        noise_sample *= 10. ** (noise_gain_db / 20.)
        # Match the noise length to the speech length.
        if noise_sample.shape[0] < sample.shape[0]:
            # Too short: repeat the noise by wrapping around.
            diff_duration = sample.shape[0] - noise_sample.shape[0]
            noise_sample = np.pad(noise_sample, (0, diff_duration), 'wrap')
        elif noise_sample.shape[0] > sample.shape[0]:
            # Too long: cut a random window of the right size.
            start_frame = random.randint(0, noise_sample.shape[0] - sample.shape[0])
            noise_sample = noise_sample[start_frame:sample.shape[0] + start_frame]
        sample += noise_sample
        return sample

    @staticmethod
    def rms_db(sample):
        """Return the RMS level of ``sample`` in decibels (10 * log10 of mean square)."""
        mean_square = np.mean(sample ** 2)
        return 10 * np.log10(mean_square)