# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import logging
import os
import sys
import io

import numpy as np
import torch
import torch.nn.functional as F

from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel


logger = logging.getLogger(__name__)


class RawAudioDataset(FairseqDataset):
    """Base class for datasets that yield raw 1-D waveforms.

    Subclasses fill ``self.sizes`` and implement ``__getitem__`` returning a
    dict with an ``"id"`` and a ``"source"`` tensor; this class provides
    post-processing, cropping/padding in :meth:`collater`, length-based
    ordering and optional size bucketing.

    Args:
        sample_rate: sample rate every loaded waveform must have
            (checked in :meth:`postprocess`).
        max_sample_size: upper bound on collated length; ``None`` means
            ``sys.maxsize`` (effectively unbounded).
        min_sample_size: minimum example size; only stored here, the
            subclasses decide how to enforce it.
        shuffle: if true, :meth:`ordered_indices` sorts by clipped size
            (descending) with random tie-breaking; otherwise file order.
        pad: pad batches to the longest example instead of cropping to the
            shortest one.
        normalize: apply ``F.layer_norm`` over the whole waveform.
        compute_mask_indices: precompute mask indices in :meth:`collater`.
        **mask_compute_kwargs: masking hyper-parameters (``mask_prob``,
            ``encoder_embed_dim``, ``conv_feature_layers`` ...); only read
            when ``compute_mask_indices`` is true.
    """

    def __init__(
        self,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.sizes = []
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.min_sample_size = min_sample_size
        self.pad = pad
        self.shuffle = shuffle
        self.normalize = normalize
        self.compute_mask_indices = compute_mask_indices
        if self.compute_mask_indices:
            self.mask_compute_kwargs = mask_compute_kwargs
            # Cache: input length in samples -> number of feature frames.
            self._features_size_map = {}
            self._C = mask_compute_kwargs["encoder_embed_dim"]
            # NOTE(review): eval executes arbitrary code from the config
            # string; only acceptable because the value is expected to come
            # from a trusted model config. Presumably the string may contain
            # list arithmetic (e.g. ``[...] * 4``), which ast.literal_eval
            # rejects -- confirm before replacing.
            self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])

    def __getitem__(self, index):
        raise NotImplementedError()

    def __len__(self):
        return len(self.sizes)

    def postprocess(self, feats, curr_sample_rate):
        """Turn a loaded waveform into a 1-D (optionally normalized) tensor.

        A 2-D input is averaged over its last dimension (assumed to be the
        channel axis, as returned by ``soundfile.read`` -- confirm for other
        loaders).

        Raises:
            Exception: if ``curr_sample_rate`` differs from ``self.sample_rate``.
        """
        if feats.dim() == 2:
            feats = feats.mean(-1)

        if curr_sample_rate != self.sample_rate:
            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")

        assert feats.dim() == 1, feats.dim()

        if self.normalize:
            with torch.no_grad():
                feats = F.layer_norm(feats, feats.shape)
        return feats

    def crop_to_max_size(self, wav, target_size):
        """Return a random contiguous window of ``target_size`` samples.

        ``wav`` is returned unchanged when it is not longer than
        ``target_size``.
        """
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav

        start = np.random.randint(0, diff + 1)
        end = size - diff + start
        return wav[start:end]

    def _compute_mask_indices(self, dims, padding_mask):
        """Sample time and channel masks for a batch.

        Args:
            dims: ``(B, T, C)`` -- batch size, feature frames, embedding dim.
            padding_mask: frame-level padding mask passed to the time masking.

        Returns:
            ``(mask_indices, mask_channel_indices)``. Each is ``None`` when its
            probability (``mask_prob`` / ``mask_channel_prob``) is not
            positive. The channel mask is broadcast over ``T`` with
            ``expand`` (giving ``(B, T, C)`` assuming ``compute_mask_indices``
            returns an array shaped like its first argument).
        """
        B, T, C = dims
        mask_indices, mask_channel_indices = None, None
        if self.mask_compute_kwargs["mask_prob"] > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_compute_kwargs["mask_prob"],
                self.mask_compute_kwargs["mask_length"],
                self.mask_compute_kwargs["mask_selection"],
                self.mask_compute_kwargs["mask_other"],
                min_masks=2,
                no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
                min_space=self.mask_compute_kwargs["mask_min_space"],
            )
            mask_indices = torch.from_numpy(mask_indices)
        if self.mask_compute_kwargs["mask_channel_prob"] > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_compute_kwargs["mask_channel_prob"],
                self.mask_compute_kwargs["mask_channel_length"],
                self.mask_compute_kwargs["mask_channel_selection"],
                self.mask_compute_kwargs["mask_channel_other"],
                no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
                min_space=self.mask_compute_kwargs["mask_channel_min_space"],
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
            )

        return mask_indices, mask_channel_indices

    @staticmethod
    def _bucket_tensor(tensor, num_pad, value):
        """Right-pad the last dimension of ``tensor`` by ``num_pad`` ``value``s."""
        return F.pad(tensor, (0, num_pad), value=value)

    def collater(self, samples):
        """Merge a list of examples into a mini-batch.

        Samples whose ``"source"`` is ``None`` are dropped; an empty dict is
        returned if none remain. Sources are zero-padded to the longest one
        (``pad=True``) or randomly cropped to the shortest one, in both cases
        capped at ``max_sample_size``.

        Returns:
            dict with ``"id"`` (LongTensor of sample ids) and ``"net_input"``
            holding ``"source"`` (batch x length) and, when padding, a boolean
            ``"padding_mask"`` (True marks padded positions). With
            ``compute_mask_indices`` the net input also carries
            ``"padding_count"``, ``"mask_indices"`` and
            ``"mask_channel_indices"``; ``"sample_size"`` (number of masked
            positions) is added when time masking is enabled.
        """
        samples = [s for s in samples if s["source"] is not None]
        if not samples:
            return {}

        sources = [s["source"] for s in samples]
        sizes = [len(s) for s in sources]

        if self.pad:
            target_size = min(max(sizes), self.max_sample_size)
        else:
            target_size = min(min(sizes), self.max_sample_size)

        collated_sources = sources[0].new_zeros(len(sources), target_size)
        padding_mask = (
            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
        )
        for i, (source, size) in enumerate(zip(sources, sizes)):
            diff = size - target_size
            if diff == 0:
                collated_sources[i] = source
            elif diff < 0:
                # Shorter than the batch: right-pad with zeros and mark the
                # trailing -diff positions as padding.
                assert self.pad
                collated_sources[i] = torch.cat(
                    [source, source.new_full((-diff,), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_sources[i] = self.crop_to_max_size(source, target_size)

        net_input = {"source": collated_sources}
        out = {"id": torch.LongTensor([s["id"] for s in samples])}
        if self.pad:
            net_input["padding_mask"] = padding_mask

        if hasattr(self, "num_buckets") and self.num_buckets > 0:
            assert self.pad, "Cannot bucket without padding first."
            # Grow the batch to the largest bucket size among its members.
            bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
            num_pad = bucket - collated_sources.size(-1)
            if num_pad:
                net_input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
                net_input["padding_mask"] = self._bucket_tensor(
                    padding_mask, num_pad, True
                )

        if self.compute_mask_indices:
            B = net_input["source"].size(0)
            T = self._get_mask_indices_dims(net_input["source"].size(-1))
            # Without padding every row is full length and no "padding_mask"
            # entry exists; use an all-False mask instead of a KeyError.
            sample_padding = net_input.get("padding_mask")
            if sample_padding is None:
                sample_padding = torch.zeros_like(
                    net_input["source"], dtype=torch.bool
                )
            # Downsample the sample-level mask to T feature frames: drop the
            # remainder that does not divide evenly, group the rest into T
            # consecutive chunks, and mark a frame as padding only if all of
            # its samples are padding.
            padding_mask_reshaped = sample_padding.clone()
            extra = padding_mask_reshaped.size(1) % T
            if extra > 0:
                padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
            padding_mask_reshaped = padding_mask_reshaped.view(
                padding_mask_reshaped.size(0), T, -1
            )
            padding_mask_reshaped = padding_mask_reshaped.all(-1)
            net_input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
            mask_indices, mask_channel_indices = self._compute_mask_indices(
                (B, T, self._C),
                padding_mask_reshaped,
            )
            net_input["mask_indices"] = mask_indices
            net_input["mask_channel_indices"] = mask_channel_indices
            # mask_indices is None when mask_prob <= 0; .sum() on it would
            # raise AttributeError.
            if mask_indices is not None:
                out["sample_size"] = mask_indices.sum().item()

        out["net_input"] = net_input
        return out

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        """Return the number of frames the conv feature stack yields for
        ``size`` input samples (standard conv output-length formula applied
        per ``(_, kernel_size, stride)`` layer), memoized per ``size``.

        With an empty layer list the input length is returned unchanged.
        """
        if size not in self._features_size_map:
            length = size
            for (_, kernel_size, stride) in self._conv_feature_layers:
                length = length + 2 * padding - dilation * (kernel_size - 1) - 1
                length = 1 + length // stride
            self._features_size_map[size] = length
        return self._features_size_map[size]

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return the size of example ``index`` as stored in ``self.sizes``.

        Unless padding is enabled the value is clipped to ``max_sample_size``,
        matching the cropping done in :meth:`collater`. This value is used
        when filtering a dataset with ``--max-positions``.
        """
        if self.pad:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order.

        With ``shuffle`` the indices are sorted by clipped size, largest
        first, with ties broken by a random permutation (np.lexsort uses the
        last key as the primary key); otherwise file order is kept.
        """

        if self.shuffle:
            order = [np.random.permutation(len(self))]
            order.append(
                np.minimum(
                    np.array(self.sizes),
                    self.max_sample_size,
                )
            )
            return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def set_bucket_info(self, num_buckets):
        """Enable size bucketing with ``num_buckets`` buckets (0 disables it).

        Buckets are computed from sizes clipped to ``max_sample_size``;
        :meth:`collater` pads each batch up to its bucket size.
        """
        self.num_buckets = num_buckets
        if self.num_buckets > 0:
            self._collated_sizes = np.minimum(
                np.array(self.sizes),
                self.max_sample_size,
            )
            self.buckets = get_buckets(
                self._collated_sizes,
                self.num_buckets,
            )
            self._bucketed_sizes = get_bucketed_sizes(
                self._collated_sizes, self.buckets
            )
            logger.info(
                f"{len(self.buckets)} bucket(s) for the audio dataset: "
                f"{self.buckets}"
            )


class FileAudioDataset(RawAudioDataset):
    """Audio dataset backed by a tab-separated manifest file.

    The manifest's first line is the root directory; every following line is
    ``<path><TAB><size>``, where the path is joined onto the root directory
    and resolved with ``parse_path``. Entries with a size below
    ``min_sample_size`` are skipped and their 0-based line index (not
    counting the root line) is recorded in ``skipped_indices``. File names
    are stored compressed with ``TextCompressor`` and, when pyarrow is
    importable, in a pyarrow array.
    """

    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        text_compression_level=TextCompressionLevel.none,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        self.text_compressor = TextCompressor(level=text_compression_level)

        skipped = 0
        self.fnames = []
        sizes = []
        self.skipped_indices = set()

        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    self.skipped_indices.add(i)
                    continue
                self.fnames.append(self.text_compressor.compress(items[0]))
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        # Best effort: a pyarrow array is more compact than a Python list.
        # Catch Exception (not everything) so Ctrl-C / SystemExit still work.
        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except Exception:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )

        self.set_bucket_info(num_buckets)

    def __getitem__(self, index):
        """Load example ``index`` as a float32 waveform via soundfile.

        When ``parse_path`` returns a two-element slice pointer, the bytes are
        read from a stored zip archive instead of a standalone file.
        """
        import soundfile as sf

        fn = self.fnames[index]
        # pyarrow scalars must be converted back to Python objects.
        fn = fn if isinstance(self.fnames, list) else fn.as_py()
        fn = self.text_compressor.decompress(fn)
        path_or_fp = os.path.join(self.root_dir, fn)
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)

        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")

        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}


class BinarizedAudioDataset(RawAudioDataset):
    """Audio dataset whose file names are stored as a binarized indexed dataset.

    Reads from ``data_dir``: ``dict.txt`` (the file-name dictionary), an
    optional ``{split}.root`` whose first line is a root directory, the
    indexed dataset ``{split}`` and ``{split}.lengths`` with one size per
    line. Filtering by ``min_sample_size`` is not supported; a smaller
    example fails an assertion.
    """

    def __init__(
        self,
        data_dir,
        split,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        from fairseq.data import data_utils, Dictionary

        self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))

        root_path = os.path.join(data_dir, f"{split}.root")
        if os.path.exists(root_path):
            with open(root_path, "r") as f:
                self.root_dir = next(f).strip()
        else:
            self.root_dir = None

        fnames_path = os.path.join(data_dir, split)
        self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
        lengths_path = os.path.join(data_dir, f"{split}.lengths")

        with open(lengths_path, "r") as f:
            for line in f:
                sz = int(line.rstrip())
                assert (
                    sz >= min_sample_size
                ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
                self.sizes.append(sz)

        self.sizes = np.array(self.sizes, dtype=np.int64)

        self.set_bucket_info(num_buckets)
        logger.info(f"loaded {len(self.fnames)} samples")

    def __getitem__(self, index):
        """Decode the file name of example ``index``, load it with soundfile
        and return the post-processed waveform."""
        import soundfile as sf

        fname = self.fnames_dict.string(self.fnames[index], separator="")
        if self.root_dir:
            fname = os.path.join(self.root_dir, fname)

        wav, curr_sample_rate = sf.read(fname)
        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}