# torchaudio_tacotron2_kss / processing_tacotron2.py
# Author: Bingsu
# Commit d07276d: "Upload 3 files"
import copy
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from torchaudio.transforms import MelSpectrogram
from transformers import Wav2Vec2PhonemeCTCTokenizer
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType, logging
# Module-level logger (transformers' logging wrapper).
logger = logging.get_logger(__name__)

# Raw-audio inputs accepted by the feature extractor / processor: a single
# waveform (ndarray or list of floats) or a batch of waveforms.
AudioType = Union[
    np.ndarray,
    List[float],
    List[np.ndarray],
    List[List[float]],
]
class Tacotron2FeatureExtractor(SequenceFeatureExtractor):
    """Turn raw waveforms into padded log-mel spectrograms for Tacotron2.

    Each waveform goes through a torchaudio ``MelSpectrogram`` (Slaney mel scale
    and norm, ``power=1``) followed by ``log(clamp(x, min=1e-5))``. The batch is
    padded along the time axis with ``SequenceFeatureExtractor.pad``, and every
    spectrogram is returned channel-first as ``(feature_size, frames)``.
    """

    model_input_names = ["mel_specgram", "mel_specgram_length", "gate_padded"]

    def __init__(
        self,
        feature_size: int = 80,  # n_mels
        sampling_rate: int = 22050,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        mel_fmin: float = 0.0,
        mel_fmax: float = 8000.0,
        padding_value: float = 0.0,
        **kwargs,
    ):
        """Store the STFT / mel filterbank configuration.

        Args:
            feature_size: Number of mel bins (``n_mels`` of the transform).
            sampling_rate: Sample rate the input audio must have.
            n_fft: FFT size of the STFT.
            hop_length: Hop between STFT frames, in samples.
            win_length: STFT window length, in samples.
            mel_fmin: Lowest frequency of the mel filterbank.
            mel_fmax: Highest frequency of the mel filterbank.
            padding_value: Value used to pad spectrograms along the time axis.
            **kwargs: Forwarded to ``SequenceFeatureExtractor.__init__``.
        """
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            **kwargs,
        )
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax

    def mel_specgram(self, waveform: torch.Tensor) -> torch.Tensor:
        """Compute the log-mel spectrogram of a single waveform.

        The ``MelSpectrogram`` transform is built on first use and cached on the
        instance as ``_mel_specgram``. ``to_dict`` excludes it from serialization.
        NOTE(review): the cache is not rebuilt if the STFT attributes are changed
        after the first call.

        Args:
            waveform: Waveform tensor. It is expected to be 1-D so that the
                transform output is 2-D, as required by ``permute(1, 0)`` below.

        Returns:
            Log-mel tensor transposed to time-major ``(frames, feature_size)``, so
            that ``pad`` can pad along the first axis.
        """
        if not hasattr(self, "_mel_specgram"):
            self._mel_specgram = MelSpectrogram(
                sample_rate=self.sampling_rate,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                f_min=self.mel_fmin,
                f_max=self.mel_fmax,
                n_mels=self.feature_size,
                mel_scale="slaney",
                normalized=False,
                power=1,
                norm="slaney",
            )
        melspectrogram = self._mel_specgram(waveform)
        # Spectral normalization: the clamp avoids log(0) on silent frames.
        output = torch.log(torch.clamp(melspectrogram, min=1e-5))
        # Transpose to (frames, n_mels) so padding happens along the time axis.
        return output.permute(1, 0)

    def __call__(
        self,
        audio: AudioType,
        sampling_rate: Optional[int] = None,
        padding: Union[bool, str] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_length: bool = False,
        return_gate_padded: bool = False,
        **kwargs,
    ) -> BatchFeature:
        """Extract padded log-mel spectrograms from one waveform or a batch.

        Args:
            audio: A single waveform (ndarray or list of floats) or a batch
                (list/tuple whose items are ndarrays, lists or tuples).
            sampling_rate: Sample rate of ``audio``. If given, it must equal
                ``self.sampling_rate``; if omitted, a warning is logged.
            padding: Padding strategy forwarded to ``self.pad``.
            return_tensors: If set, outputs are converted with
                ``BatchFeature.convert_to_tensors``.
            return_length: Add ``"mel_specgram_length"``, the number of frames
                of each spectrogram before padding. It is a plain int when a
                single unbatched waveform is given and ``return_tensors`` is None.
            return_gate_padded: Add ``"gate_padded"``, per-frame float32 stop
                targets. The value is 0 before the last real frame and 1 on the
                last real frame and on every padding frame.
            **kwargs: Forwarded to ``self.pad``.

        Returns:
            A ``BatchFeature``. Its ``"mel_specgram"`` entry holds one
            ``(feature_size, frames)`` spectrogram per input (always batched).

        Raises:
            ValueError: If ``sampling_rate`` mismatches, or ``audio`` is an empty
                sequence.
        """
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided `audio` input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if isinstance(audio, (list, tuple)) and len(audio) == 0:
            raise ValueError("`audio` must contain at least one waveform, got an empty sequence.")

        is_batched = isinstance(audio, (list, tuple)) and isinstance(
            audio[0], (np.ndarray, tuple, list)
        )

        # Normalise to a list of float32 arrays (MelSpectrogram's window and
        # filterbank are float32, so float64 input must be cast).
        if is_batched:
            waveforms = [np.asarray(speech, dtype=np.float32) for speech in audio]
        elif not isinstance(audio, np.ndarray):
            waveforms = [np.asarray(audio, dtype=np.float32)]
        elif audio.dtype == np.float64:
            # `==` rather than `is`, so non-native byte-order float64 also matches.
            waveforms = [audio.astype(np.float32)]
        else:
            waveforms = [audio]

        features = [
            self.mel_specgram(torch.from_numpy(waveform)).numpy()
            for waveform in waveforms
        ]

        encoded_inputs = BatchFeature({"mel_specgram": features})
        padded_inputs = self.pad(
            encoded_inputs,
            padding=padding,
            return_attention_mask=return_gate_padded,
            **kwargs,
        )

        if return_length:
            mel_specgram_length = [mel.shape[0] for mel in features]
            if len(mel_specgram_length) == 1 and return_tensors is None:
                mel_specgram_length = mel_specgram_length[0]
            padded_inputs["mel_specgram_length"] = mel_specgram_length

        if return_gate_padded:
            # `pad` is called without `return_tensors`, so the mask comes back as
            # a list of 1-D int arrays (1 = real frame, 0 = padding). Arithmetic
            # on the list itself would raise TypeError, so build each gate per
            # example.
            gates = []
            for mask in padded_inputs.pop("attention_mask"):
                gate = 1 - np.asarray(mask)
                # Shift left so the last real frame (followed by padding) becomes 1.
                gate = np.roll(gate, -1)
                # The wrapped-around first element is replaced: the final frame
                # always marks the stop.
                gate[-1] = 1
                gates.append(gate.astype(np.float32))
            if len({gate.shape[0] for gate in gates}) == 1:
                padded_inputs["gate_padded"] = np.stack(gates)
            else:
                # Unequal lengths (e.g. padding disabled) cannot be stacked.
                padded_inputs["gate_padded"] = gates

        # Back to channel-first (feature_size, frames) per example.
        padded_inputs["mel_specgram"] = [
            np.asarray(spec, dtype=np.float32).transpose(1, 0)
            for spec in padded_inputs["mel_specgram"]
        ]

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        The cached ``_mel_specgram`` transform is skipped before copying, so it
        is never deep-copied only to be thrown away.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
        """
        state = {
            key: value
            for key, value in self.__dict__.items()
            if key != "_mel_specgram"
        }
        output = copy.deepcopy(state)
        output["feature_extractor_type"] = self.__class__.__name__
        return output
class Tacotron2Processor(ProcessorMixin):
    """Pair a Tacotron2 feature extractor with a phoneme tokenizer.

    Text goes to ``tokenizer`` and audio to ``feature_extractor``. When both are
    supplied, the two outputs are merged into a single ``BatchFeature``.
    """

    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "Wav2Vec2PhonemeCTCTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        # NOTE(review): ProcessorMixin.__init__ is not called; the attributes are
        # assigned directly instead.
        self.feature_extractor, self.tokenizer = feature_extractor, tokenizer
        self.current_processor = feature_extractor

    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[AudioType] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_length: bool = True,
        **kwargs,
    ) -> Any:
        """Tokenize ``text`` and/or extract features from ``audio``.

        Args:
            text: Text for the tokenizer. It is always padded, and no attention
                mask is requested.
            audio: Raw audio for the feature extractor.
            return_tensors: Tensor type forwarded to both sub-processors.
            return_length: Forwarded to both sub-processors as ``return_length``.
            **kwargs: Forwarded to the feature extractor only.

        Returns:
            The tokenizer output, the feature-extractor output, or both merged
            into one ``BatchFeature``. On a duplicate key, the tokenizer's value wins.

        Raises:
            ValueError: If neither ``text`` nor ``audio`` is given.
        """
        if audio is None and text is None:
            raise ValueError(
                "You have to specify either text or audio. Both cannot be none."
            )

        encoding = None
        if text is not None:
            encoding = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=True,
                return_attention_mask=False,
                return_length=return_length,
            )

        if audio is None:
            return encoding

        features = self.feature_extractor(
            audio,
            return_tensors=return_tensors,
            return_length=return_length,
            **kwargs,
        )

        if text is None:
            return features

        merged = dict(features)
        merged.update(encoding)
        return BatchFeature(merged)

    def batch_decode(self, *args, **kwargs):
        """
        Decode a batch of token-id sequences by delegating every argument to
        PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]; see that
        docstring for the accepted parameters.
        """
        tokenizer = self.tokenizer
        return tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Decode one token-id sequence by delegating every argument to
        PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]; see that
        docstring for the accepted parameters.
        """
        tokenizer = self.tokenizer
        return tokenizer.decode(*args, **kwargs)