|
import copy |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
from torchaudio.transforms import MelSpectrogram |
|
from transformers import Wav2Vec2PhonemeCTCTokenizer |
|
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor |
|
from transformers.feature_extraction_utils import BatchFeature |
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.utils import TensorType, logging |
|
|
|
logger = logging.get_logger(__name__) |
|
AudioType = Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]] |
|
|
|
|
|
class Tacotron2FeatureExtractor(SequenceFeatureExtractor): |
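    """
    Constructs a Tacotron2 feature extractor that converts raw speech waveforms into log-mel
    spectrograms. The defaults follow the standard Tacotron2 recipe: 80 mel bins, 22050 Hz
    audio, a 1024-point FFT with a 1024-sample window and a 256-sample hop, and mel
    frequencies from 0 Hz to 8000 Hz.
    """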
|
model_input_names = ["mel_specgram", "mel_specgram_length", "gate_padded"] |
|
|
|
def __init__( |
|
self, |
|
feature_size: int = 80, |
|
sampling_rate: int = 22050, |
|
n_fft: int = 1024, |
|
hop_length: int = 256, |
|
win_length: int = 1024, |
|
mel_fmin: float = 0.0, |
|
mel_fmax: float = 8000.0, |
|
padding_value: float = 0.0, |
|
**kwargs, |
|
): |
|
super().__init__( |
|
feature_size=feature_size, |
|
sampling_rate=sampling_rate, |
|
padding_value=padding_value, |
|
**kwargs, |
|
) |
|
self.feature_size = feature_size |
|
self.sampling_rate = sampling_rate |
|
self.n_fft = n_fft |
|
self.hop_length = hop_length |
|
self.win_length = win_length |
|
self.mel_fmin = mel_fmin |
|
self.mel_fmax = mel_fmax |
|
|
|
def mel_specgram(self, waveform: torch.Tensor) -> torch.Tensor: |
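        """
        Compute a log-mel spectrogram for a single waveform.

        The underlying `torchaudio` `MelSpectrogram` module is built lazily on first use and
        excluded from serialization (see `to_dict`). The result is returned frames-first,
        with shape `(time, n_mels)`, so that batch padding can run along the time axis.
        """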
|
if not hasattr(self, "_mel_specgram"): |
|
self._mel_specgram = MelSpectrogram( |
|
sample_rate=self.sampling_rate, |
|
n_fft=self.n_fft, |
|
win_length=self.win_length, |
|
hop_length=self.hop_length, |
|
f_min=self.mel_fmin, |
|
f_max=self.mel_fmax, |
|
n_mels=self.feature_size, |
|
mel_scale="slaney", |
|
normalized=False, |
|
                power=1.0,  # magnitude (not power) spectrogram, as used by Tacotron2
|
norm="slaney", |
|
) |
|
melspectrogram = self._mel_specgram(waveform) |
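        # Clamp to avoid log(0), then apply natural-log dynamic range compression.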
|
|
|
output = torch.log(torch.clamp(melspectrogram, min=1e-5)) |
|
|
|
|
|
return output.permute(1, 0) |
|
|
|
def __call__( |
|
self, |
|
audio: AudioType, |
|
sampling_rate: Optional[int] = None, |
|
padding: Union[bool, str] = True, |
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
return_length: bool = False, |
|
return_gate_padded: bool = False, |
|
**kwargs, |
|
) -> BatchFeature: |
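        """
        Featurize a single waveform or a batch of waveforms into padded log-mel spectrograms.

        Args:
            audio: One waveform (`np.ndarray` or list of floats) or a batch of waveforms.
            sampling_rate: Sampling rate of `audio`; must match `self.sampling_rate`.
            padding: Padding strategy forwarded to `SequenceFeatureExtractor.pad`.
            return_tensors: If set, convert outputs to this tensor type (e.g. `"pt"` or `"np"`).
            return_length: If `True`, also return each spectrogram's unpadded length in frames.
            return_gate_padded: If `True`, also return Tacotron2 stop-token (gate) targets
                derived from the padding mask.

        Returns:
            A [`BatchFeature`] with `"mel_specgram"` entries of shape `(n_mels, time)` and,
            optionally, `"mel_specgram_length"` and `"gate_padded"`.
        """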
|
|
|
if sampling_rate is not None: |
|
if sampling_rate != self.sampling_rate: |
|
                raise ValueError(

                    f"This feature extractor was trained with a sampling rate of {self.sampling_rate} Hz. Please"

                    f" make sure the provided `audio` input was sampled at {self.sampling_rate} Hz and not"

                    f" {sampling_rate} Hz."

                )
|
|
|
else: |
|
logger.warning( |
|
"It is strongly recommended to pass the `sampling_rate` argument to this function. " |
|
"Failing to do so can result in silent errors that might be hard to debug." |
|
) |
|
|
|
is_batched = bool( |
|
isinstance(audio, (list, tuple)) |
|
            and isinstance(audio[0], (np.ndarray, tuple, list))
|
) |
|
|
|
if is_batched: |
|
audio = [np.asarray(speech, dtype=np.float32) for speech in audio] |
|
        elif not isinstance(audio, np.ndarray):
|
audio = np.asarray(audio, dtype=np.float32) |
|
        elif isinstance(audio, np.ndarray) and audio.dtype == np.float64:
|
audio = audio.astype(np.float32) |
|
|
|
|
|
if not is_batched: |
|
audio = [audio] |
|
|
|
features = [ |
|
self.mel_specgram(torch.from_numpy(one_waveform)).numpy() |
|
for one_waveform in audio |
|
] |
|
|
|
encoded_inputs = BatchFeature({"mel_specgram": features}) |
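        # Pad every spectrogram to the longest in the batch along the time axis. The attention
        # mask is requested only when gate (stop-token) targets need to be derived from it.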
|
|
|
padded_inputs = self.pad( |
|
encoded_inputs, |
|
padding=padding, |
|
return_attention_mask=return_gate_padded, |
|
**kwargs, |
|
) |
|
|
|
if return_length: |
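            # Lengths are measured in frames on the unpadded spectrograms (shape[0] is time).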
|
mel_specgram_length = [mel.shape[0] for mel in features] |
|
if len(mel_specgram_length) == 1 and return_tensors is None: |
|
mel_specgram_length = mel_specgram_length[0] |
|
padded_inputs["mel_specgram_length"] = mel_specgram_length |
|
|
|
if return_gate_padded: |
|
            # `pad` returns a list of per-example masks when no tensor type is set, so stack it
            # first. The mask marks real frames with 1 and padding with 0; the gate target is
            # the complement, shifted left one step so the stop token fires on the last frame.
            gate_padded = 1 - np.asarray(padded_inputs.pop("attention_mask"))
|
gate_padded = np.roll(gate_padded, -1, axis=1) |
|
gate_padded[:, -1] = 1 |
|
gate_padded = gate_padded.astype(np.float32) |
|
padded_inputs["gate_padded"] = gate_padded |
|
|
|
mel_specgram = padded_inputs["mel_specgram"] |
|
if isinstance(mel_specgram[0], list): |
|
padded_inputs["mel_specgram"] = [ |
|
np.asarray(feature, dtype=np.float32) for feature in mel_specgram |
|
] |
|
|
|
padded_inputs["mel_specgram"] = [ |
|
spec.transpose(1, 0) for spec in padded_inputs["mel_specgram"] |
|
] |
|
|
|
if return_tensors is not None: |
|
padded_inputs = padded_inputs.convert_to_tensors(return_tensors) |
|
|
|
return padded_inputs |
|
|
|
def to_dict(self) -> Dict[str, Any]: |
|
""" |
|
Serializes this instance to a Python dictionary. |
|
Returns: |
|
`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance. |
|
""" |
|
output = copy.deepcopy(self.__dict__) |
|
output["feature_extractor_type"] = self.__class__.__name__ |
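        # Drop the lazily created torchaudio module: it is not JSON-serializable and is
        # rebuilt on first use after loading.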
|
output.pop("_mel_specgram", None) |
|
|
|
return output |
|
|
|
|
|
class Tacotron2Processor(ProcessorMixin): |
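    """
    Bundles a [`Tacotron2FeatureExtractor`] and a [`Wav2Vec2PhonemeCTCTokenizer`] into a
    single processor: text is routed to the tokenizer, audio to the feature extractor.
    """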
|
feature_extractor_class = "AutoFeatureExtractor" |
|
tokenizer_class = "Wav2Vec2PhonemeCTCTokenizer" |
|
|
|
    def __init__(self, feature_extractor, tokenizer):

        # Let ProcessorMixin register and type-check the two attributes.
        super().__init__(feature_extractor=feature_extractor, tokenizer=tokenizer)

        self.current_processor = self.feature_extractor
|
|
|
def __call__( |
|
self, |
|
        text: Optional[Union[str, List[str]]] = None,
|
audio: Optional[AudioType] = None, |
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
return_length: bool = True, |
|
**kwargs, |
|
) -> Any: |
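        """
        Tokenize `text` and/or featurize `audio`. Returns the tokenizer output, the feature
        extractor output, or a merged [`BatchFeature`] when both are given. Extra keyword
        arguments are forwarded to the feature extractor.
        """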
|
if text is None and audio is None: |
|
            raise ValueError(

                "You have to specify either `text` or `audio`; both cannot be `None`."

            )
|
|
|
if text is not None: |
|
encoding = self.tokenizer( |
|
text, |
|
return_tensors=return_tensors, |
|
padding=True, |
|
return_attention_mask=False, |
|
return_length=return_length, |
|
) |
|
|
|
if audio is not None: |
|
features = self.feature_extractor( |
|
audio, |
|
return_tensors=return_tensors, |
|
return_length=return_length, |
|
**kwargs, |
|
) |
|
|
|
if text is not None and audio is not None: |
|
return BatchFeature({**features, **encoding}) |
|
elif text is not None: |
|
return encoding |
|
else: |
|
return features |
|
|
|
def batch_decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please |
|
refer to the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.batch_decode(*args, **kwargs) |
|
|
|
def decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer |
|
to the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.decode(*args, **kwargs) |
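

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint name below is an assumption for illustration;
    # any Wav2Vec2PhonemeCTCTokenizer checkpoint works (phonemization requires the
    # `phonemizer`/espeak backend to be installed).
    feature_extractor = Tacotron2FeatureExtractor()
    tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained(
        "facebook/wav2vec2-lv-60-espeak-cv-ft"  # assumed checkpoint
    )
    processor = Tacotron2Processor(feature_extractor, tokenizer)

    waveform = np.random.randn(22050).astype(np.float32)  # one second of noise
    inputs = processor(
        text="hello world",
        audio=waveform,
        sampling_rate=22050,  # forwarded to the feature extractor via **kwargs
        return_tensors="np",
    )
    print(inputs["mel_specgram"].shape)  # (1, 80, n_frames)
    print(inputs["input_ids"].shape)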
|
|