File size: 8,270 Bytes
ad16788 |
|
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
from typeguard import check_argument_types
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.inversible_interface import InversibleInterface
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.tts.abs_tts import AbsTTS
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
    # torch<1.6.0 ships no native AMP, so expose a do-nothing context manager
    # with the same call shape; ``with autocast(False):`` then works everywhere.
    @contextmanager
    def autocast(enabled=True):
        yield

else:
    from torch.cuda.amp import autocast
class ESPnetTTSModel(AbsESPnetModel):
    """Wrap feature extraction, normalization, and a TTS network into one model.

    Raw ``speech`` is turned into target features (plus optional pitch and
    energy targets), the targets are normalized, and everything is forwarded
    to ``self.tts``.
    """

    def __init__(
        self,
        feats_extract: Optional[AbsFeatsExtract],
        pitch_extract: Optional[AbsFeatsExtract],
        energy_extract: Optional[AbsFeatsExtract],
        # NOTE: for class objects ``A and B`` evaluates to ``B``, so at runtime
        # the three annotations below are ``Optional[InversibleInterface]`` and
        # that is what ``check_argument_types`` actually verifies.
        normalize: Optional[AbsNormalize and InversibleInterface],
        pitch_normalize: Optional[AbsNormalize and InversibleInterface],
        energy_normalize: Optional[AbsNormalize and InversibleInterface],
        tts: AbsTTS,
    ):
        """Store the sub-modules.

        Args:
            feats_extract: Turns raw speech into target features; None means
                ``speech`` is already a feature tensor and is used as-is.
            pitch_extract: Optional pitch target extractor.
            energy_extract: Optional energy target extractor.
            normalize: Optional normalizer for the target features.
            pitch_normalize: Optional normalizer for pitch targets.
            energy_normalize: Optional normalizer for energy targets.
            tts: The TTS network that consumes the prepared batch.
        """
        assert check_argument_types()
        super().__init__()
        self.feats_extract = feats_extract
        self.pitch_extract = pitch_extract
        self.energy_extract = energy_extract
        self.normalize = normalize
        self.pitch_normalize = pitch_normalize
        self.energy_normalize = energy_normalize
        self.tts = tts

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor = None,
        durations_lengths: torch.Tensor = None,
        pitch: torch.Tensor = None,
        pitch_lengths: torch.Tensor = None,
        energy: torch.Tensor = None,
        energy_lengths: torch.Tensor = None,
        spembs: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Prepare targets from ``speech`` and run the TTS network on a batch.

        Args:
            text: Input token tensor (presumably padded ids, (B, T_text) —
                not checked here).
            text_lengths: Valid length of each ``text`` entry.
            speech: Raw speech, or precomputed features when
                ``feats_extract`` is None.
            speech_lengths: Valid length of each ``speech`` entry.
            durations, durations_lengths: Optional durations, forwarded to the
                extractors and to ``self.tts``.
            pitch, pitch_lengths: Optional pitch targets; when None and
                ``pitch_extract`` is set, they are extracted from ``speech``.
            energy, energy_lengths: Same as pitch, for energy.
            spembs: Optional speaker embeddings, forwarded to ``self.tts``.
            **kwargs: Extra arguments forwarded to ``self.tts`` unchanged.

        Returns:
            Whatever ``self.tts(...)`` returns (annotated as loss, stats dict,
            and batch weight).
        """
        # Feature extraction / normalization runs with mixed precision off.
        with autocast(False):
            if self.feats_extract is not None:
                feats, feats_lengths = self.feats_extract(speech, speech_lengths)
            else:
                # ``speech`` already holds features.
                feats, feats_lengths = speech, speech_lengths

            # Extract auxiliary targets only when the caller did not supply them.
            if self.pitch_extract is not None and pitch is None:
                pitch, pitch_lengths = self.pitch_extract(
                    speech,
                    speech_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )
            if self.energy_extract is not None and energy is None:
                energy, energy_lengths = self.energy_extract(
                    speech,
                    speech_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )

            # Normalize. Absent auxiliary streams are skipped so a configured
            # normalizer never receives ``None``.
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
            if self.pitch_normalize is not None and pitch is not None:
                pitch, pitch_lengths = self.pitch_normalize(pitch, pitch_lengths)
            if self.energy_normalize is not None and energy is not None:
                energy, energy_lengths = self.energy_normalize(
                    energy, energy_lengths
                )

        # Forward only the optional inputs that are actually available.
        if spembs is not None:
            kwargs.update(spembs=spembs)
        if durations is not None:
            kwargs.update(durations=durations, durations_lengths=durations_lengths)
        if self.pitch_extract is not None and pitch is not None:
            kwargs.update(pitch=pitch, pitch_lengths=pitch_lengths)
        if self.energy_extract is not None and energy is not None:
            kwargs.update(energy=energy, energy_lengths=energy_lengths)

        return self.tts(
            text=text,
            text_lengths=text_lengths,
            speech=feats,
            speech_lengths=feats_lengths,
            **kwargs,
        )

    def collect_feats(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor = None,
        durations_lengths: torch.Tensor = None,
        pitch: torch.Tensor = None,
        pitch_lengths: torch.Tensor = None,
        energy: torch.Tensor = None,
        energy_lengths: torch.Tensor = None,
        spembs: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        """Return the un-normalized targets for ``speech``.

        Unlike ``forward``, no normalization is applied. When an extractor is
        configured, it runs even if the caller supplied ``pitch`` / ``energy``,
        and its result replaces the supplied value. ``text``, ``text_lengths``
        and ``spembs`` are accepted but unused here (presumably to share the
        ``forward`` call signature).

        Returns:
            Dict with ``feats`` and ``feats_lengths``, plus ``pitch`` /
            ``pitch_lengths`` and ``energy`` / ``energy_lengths`` when present.
        """
        if self.feats_extract is not None:
            feats, feats_lengths = self.feats_extract(speech, speech_lengths)
        else:
            # ``speech`` already holds features.
            feats, feats_lengths = speech, speech_lengths
        feats_dict = {"feats": feats, "feats_lengths": feats_lengths}

        if self.pitch_extract is not None:
            pitch, pitch_lengths = self.pitch_extract(
                speech,
                speech_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )
        if self.energy_extract is not None:
            energy, energy_lengths = self.energy_extract(
                speech,
                speech_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )

        if pitch is not None:
            feats_dict.update(pitch=pitch, pitch_lengths=pitch_lengths)
        if energy is not None:
            feats_dict.update(energy=energy, energy_lengths=energy_lengths)
        return feats_dict

    def inference(
        self,
        text: torch.Tensor,
        speech: torch.Tensor = None,
        spembs: torch.Tensor = None,
        durations: torch.Tensor = None,
        pitch: torch.Tensor = None,
        energy: torch.Tensor = None,
        **decode_config,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Synthesize features for a single (unbatched) utterance.

        Args:
            text: Input tokens of one utterance.
            speech: Reference speech. Required when teacher forcing is enabled
                or when ``self.tts`` has a truthy ``use_gst`` attribute.
            spembs: Optional speaker embedding.
            durations: Optional durations, used under teacher forcing.
            pitch: Optional pitch; replaced by extracted pitch under teacher
                forcing when ``pitch_extract`` is set.
            energy: Same as pitch, for energy.
            **decode_config: Forwarded to ``self.tts.inference``. The key
                ``use_teacher_forcing`` is read here and defaults to False.

        Returns:
            ``(outs, outs_denorm, probs, att_ws, ref_embs, ar_prior_loss)``.
            ``outs_denorm`` is ``outs`` passed through ``normalize.inverse``
            when a normalizer is configured, otherwise ``outs`` itself.

        Raises:
            RuntimeError: If ``speech`` is required but not given.
        """
        kwargs = {}
        # Teacher forcing stays off unless the caller explicitly enables it.
        use_teacher_forcing = decode_config.get("use_teacher_forcing", False)

        if use_teacher_forcing or getattr(self.tts, "use_gst", False):
            if speech is None:
                raise RuntimeError("missing required argument: 'speech'")
            if self.feats_extract is not None:
                feats = self.feats_extract(speech[None])[0][0]
            else:
                # ``speech`` already holds features.
                feats = speech
            if self.normalize is not None:
                # NOTE(review): if ``normalize`` works in place (its
                # ``inverse`` does, see below) and ``feats`` aliases ``speech``,
                # the caller's tensor may be modified — confirm.
                feats = self.normalize(feats[None])[0][0]
            kwargs["speech"] = feats

        if use_teacher_forcing:
            if durations is not None:
                kwargs["durations"] = durations
            # Batch-of-one view for the extractors; None must stay None rather
            # than being subscripted.
            durations_batch = durations[None] if durations is not None else None

            if self.pitch_extract is not None:
                pitch = self.pitch_extract(
                    speech[None],
                    feats_lengths=torch.LongTensor([len(feats)]),
                    durations=durations_batch,
                )[0][0]
            if self.pitch_normalize is not None and pitch is not None:
                pitch = self.pitch_normalize(pitch[None])[0][0]
            if pitch is not None:
                kwargs["pitch"] = pitch

            if self.energy_extract is not None:
                energy = self.energy_extract(
                    speech[None],
                    feats_lengths=torch.LongTensor([len(feats)]),
                    durations=durations_batch,
                )[0][0]
            if self.energy_normalize is not None and energy is not None:
                energy = self.energy_normalize(energy[None])[0][0]
            if energy is not None:
                kwargs["energy"] = energy

        if spembs is not None:
            kwargs["spembs"] = spembs

        outs, probs, att_ws, ref_embs, ar_prior_loss = self.tts.inference(
            text=text,
            **kwargs,
            **decode_config,
        )

        if self.normalize is not None:
            # NOTE: normalize.inverse is in-place operation, hence the clone.
            outs_denorm = self.normalize.inverse(outs.clone()[None])[0][0]
        else:
            outs_denorm = outs

        return outs, outs_denorm, probs, att_ws, ref_embs, ar_prior_loss
|