import copy
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torchaudio.transforms import MelSpectrogram

from transformers import Wav2Vec2PhonemeCTCTokenizer
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType, logging

logger = logging.get_logger(__name__)

AudioType = Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]]


class Tacotron2FeatureExtractor(SequenceFeatureExtractor):
    model_input_names = ["mel_specgram", "mel_specgram_length", "gate_padded"]

    def __init__(
        self,
        feature_size: int = 80,  # n_mels
        sampling_rate: int = 22050,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        mel_fmin: float = 0.0,
        mel_fmax: float = 8000.0,
        padding_value: float = 0.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            **kwargs,
        )
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax

    def mel_specgram(self, waveform: torch.Tensor) -> torch.Tensor:
        # Build the torch transform lazily so it never has to be serialized
        # with the config (see `to_dict` below).
        if not hasattr(self, "_mel_specgram"):
            self._mel_specgram = MelSpectrogram(
                sample_rate=self.sampling_rate,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                f_min=self.mel_fmin,
                f_max=self.mel_fmax,
                n_mels=self.feature_size,
                mel_scale="slaney",
                normalized=False,
                power=1,
                norm="slaney",
            )
        melspectrogram = self._mel_specgram(waveform)
        # Spectral normalization: log dynamic-range compression with a floor.
        output = torch.log(torch.clamp(melspectrogram, min=1e-5))
        # Transpose to (time, n_mels) so padding happens along the time axis.
        return output.permute(1, 0)

    def __call__(
        self,
        audio: AudioType,
        sampling_rate: Optional[int] = None,
        padding: Union[bool, str] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_length: bool = False,
        return_gate_padded: bool = False,
        **kwargs,
    ) -> BatchFeature:
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided `audio` input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched = bool(
            isinstance(audio, (list, tuple))
            and isinstance(audio[0], (np.ndarray, tuple, list))
        )

        if is_batched:
            audio = [np.asarray(speech, dtype=np.float32) for speech in audio]
        elif not isinstance(audio, np.ndarray):
            audio = np.asarray(audio, dtype=np.float32)
        elif isinstance(audio, np.ndarray) and audio.dtype is np.dtype(np.float64):
            audio = audio.astype(np.float32)

        # Always operate on a batch internally.
        if not is_batched:
            audio = [audio]

        features = [
            self.mel_specgram(torch.from_numpy(one_waveform)).numpy()
            for one_waveform in audio
        ]
        encoded_inputs = BatchFeature({"mel_specgram": features})

        padded_inputs = self.pad(
            encoded_inputs,
            padding=padding,
            return_attention_mask=return_gate_padded,
            **kwargs,
        )

        if return_length:
            mel_specgram_length = [mel.shape[0] for mel in features]
            if len(mel_specgram_length) == 1 and return_tensors is None:
                mel_specgram_length = mel_specgram_length[0]
            padded_inputs["mel_specgram_length"] = mel_specgram_length

        if return_gate_padded:
            # The stop-token ("gate") target is the inverted attention mask,
            # shifted left by one frame so the gate fires on the last valid
            # frame rather than the first padded one. `pad` returns a list of
            # per-example arrays, so stack it into a batch array first.
            gate_padded = 1 - np.asarray(padded_inputs.pop("attention_mask"))
            gate_padded = np.roll(gate_padded, -1, axis=1)
            gate_padded[:, -1] = 1
            gate_padded = gate_padded.astype(np.float32)
            padded_inputs["gate_padded"] = gate_padded

        mel_specgram = padded_inputs["mel_specgram"]
        if isinstance(mel_specgram[0], list):
            padded_inputs["mel_specgram"] = [
                np.asarray(feature, dtype=np.float32) for feature in mel_specgram
            ]
        # Transpose back to (n_mels, time), the layout the model consumes.
        padded_inputs["mel_specgram"] = [
            spec.transpose(1, 0) for spec in padded_inputs["mel_specgram"]
        ]

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["feature_extractor_type"] = self.__class__.__name__
        # Drop the cached torch transform; it is not JSON-serializable.
        output.pop("_mel_specgram", None)
        return output
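

# ---------------------------------------------------------------------------
# Output contract of the feature extractor, illustrated for a single ~1 s clip
# at 22.05 kHz with the defaults above. A sketch for orientation, not an API
# guarantee; the frame count assumes torchaudio's default centered STFT
# (1 + 22050 // 256 = 87 frames):
#
#     extractor = Tacotron2FeatureExtractor()
#     out = extractor(
#         np.zeros(22050, dtype=np.float32),
#         sampling_rate=22050,
#         return_length=True,
#         return_gate_padded=True,
#     )
#     out["mel_specgram"][0].shape   # (80, 87): (n_mels, frames)
#     out["mel_specgram_length"]     # 87: frames before padding
#     out["gate_padded"].shape       # (1, 87): stop-token targets
# ---------------------------------------------------------------------------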


class Tacotron2Processor(ProcessorMixin):
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "Wav2Vec2PhonemeCTCTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        audio: Optional[AudioType] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_length: bool = True,
        **kwargs,
    ) -> Any:
        if text is None and audio is None:
            raise ValueError(
                "You have to specify either text or audio. Both cannot be `None`."
            )

        if text is not None:
            encoding = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=True,
                return_attention_mask=False,
                return_length=return_length,
            )
        if audio is not None:
            features = self.feature_extractor(
                audio,
                return_tensors=return_tensors,
                return_length=return_length,
                **kwargs,
            )

        if text is not None and audio is not None:
            return BatchFeature({**features, **encoding})
        elif text is not None:
            return encoding
        else:
            return features

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)