from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
from typeguard import check_argument_types

from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.layers.inversible_interface import InversibleInterface
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.tts.abs_tts import AbsTTS
from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
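
# NOTE: With the fallback above, ``with autocast(False):`` inside ``forward()``
# is a no-op on torch<1.6.0, so feature extraction simply runs in full
# precision regardless of the installed torch version.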


class ESPnetTTSModel(AbsESPnetModel):
    """ESPnet model for the text-to-speech task.

    Wraps optional feature/pitch/energy extractors and normalizers around an
    ``AbsTTS`` module so that training and inference share one interface.
    """

    def __init__(
        self,
        feats_extract: Optional[AbsFeatsExtract],
        pitch_extract: Optional[AbsFeatsExtract],
        energy_extract: Optional[AbsFeatsExtract],
        normalize: Optional[AbsNormalize and InversibleInterface],
        pitch_normalize: Optional[AbsNormalize and InversibleInterface],
        energy_normalize: Optional[AbsNormalize and InversibleInterface],
        tts: AbsTTS,
    ):
        assert check_argument_types()
        super().__init__()
        self.feats_extract = feats_extract
        self.pitch_extract = pitch_extract
        self.energy_extract = energy_extract
        self.normalize = normalize
        self.pitch_normalize = pitch_normalize
        self.energy_normalize = energy_normalize
        self.tts = tts

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor = None,
        durations_lengths: torch.Tensor = None,
        pitch: torch.Tensor = None,
        pitch_lengths: torch.Tensor = None,
        energy: torch.Tensor = None,
        energy_lengths: torch.Tensor = None,
        spembs: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Extract and normalize features, then compute the TTS loss.

        Returns the ``(loss, stats, weight)`` tuple produced by ``self.tts``.
        """
        with autocast(False):
            # Extract features
            if self.feats_extract is not None:
                feats, feats_lengths = self.feats_extract(speech, speech_lengths)
            else:
                # Use precomputed features directly
                feats, feats_lengths = speech, speech_lengths

            # Extract auxiliary features
            if self.pitch_extract is not None and pitch is None:
                pitch, pitch_lengths = self.pitch_extract(
                    speech,
                    speech_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )
            if self.energy_extract is not None and energy is None:
                energy, energy_lengths = self.energy_extract(
                    speech,
                    speech_lengths,
                    feats_lengths=feats_lengths,
                    durations=durations,
                    durations_lengths=durations_lengths,
                )

            # Normalize
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
            if self.pitch_normalize is not None:
                pitch, pitch_lengths = self.pitch_normalize(pitch, pitch_lengths)
            if self.energy_normalize is not None:
                energy, energy_lengths = self.energy_normalize(energy, energy_lengths)

        # Update kwargs for additional auxiliary inputs
        if spembs is not None:
            kwargs.update(spembs=spembs)
        if durations is not None:
            kwargs.update(durations=durations, durations_lengths=durations_lengths)
        if self.pitch_extract is not None and pitch is not None:
            kwargs.update(pitch=pitch, pitch_lengths=pitch_lengths)
        if self.energy_extract is not None and energy is not None:
            kwargs.update(energy=energy, energy_lengths=energy_lengths)

        return self.tts(
            text=text,
            text_lengths=text_lengths,
            speech=feats,
            speech_lengths=feats_lengths,
            **kwargs,
        )

    def collect_feats(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor = None,
        durations_lengths: torch.Tensor = None,
        pitch: torch.Tensor = None,
        pitch_lengths: torch.Tensor = None,
        energy: torch.Tensor = None,
        energy_lengths: torch.Tensor = None,
        spembs: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        """Extract features (and optional pitch/energy) and return them as a dict."""
        if self.feats_extract is not None:
            feats, feats_lengths = self.feats_extract(speech, speech_lengths)
        else:
            feats, feats_lengths = speech, speech_lengths
        feats_dict = {"feats": feats, "feats_lengths": feats_lengths}

        if self.pitch_extract is not None:
            pitch, pitch_lengths = self.pitch_extract(
                speech,
                speech_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )
        if self.energy_extract is not None:
            energy, energy_lengths = self.energy_extract(
                speech,
                speech_lengths,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
            )
        if pitch is not None:
            feats_dict.update(pitch=pitch, pitch_lengths=pitch_lengths)
        if energy is not None:
            feats_dict.update(energy=energy, energy_lengths=energy_lengths)
        return feats_dict

    def inference(
        self,
        text: torch.Tensor,
        speech: torch.Tensor = None,
        spembs: torch.Tensor = None,
        durations: torch.Tensor = None,
        pitch: torch.Tensor = None,
        energy: torch.Tensor = None,
        **decode_config,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Run TTS inference and return normalized and denormalized outputs.

        The remaining return values (probs, att_ws, ref_embs, ar_prior_loss)
        are passed through from ``self.tts.inference``.
        """
        kwargs = {}

        # Teacher forcing and GST (use_gst defaults to False) both require the
        # ground-truth speech, so prepare its features here.
        if decode_config["use_teacher_forcing"] or getattr(self.tts, "use_gst", False):
            if speech is None:
                raise RuntimeError("missing required argument: 'speech'")
            if self.feats_extract is not None:
                feats = self.feats_extract(speech[None])[0][0]
            else:
                feats = speech
            if self.normalize is not None:
                feats = self.normalize(feats[None])[0][0]
            kwargs["speech"] = feats

        if decode_config["use_teacher_forcing"]:
            if durations is not None:
                kwargs["durations"] = durations

            if self.pitch_extract is not None:
                pitch = self.pitch_extract(
                    speech[None],
                    feats_lengths=torch.LongTensor([len(feats)]),
                    durations=durations[None],
                )[0][0]
            if self.pitch_normalize is not None:
                pitch = self.pitch_normalize(pitch[None])[0][0]
            if pitch is not None:
                kwargs["pitch"] = pitch

            if self.energy_extract is not None:
                energy = self.energy_extract(
                    speech[None],
                    feats_lengths=torch.LongTensor([len(feats)]),
                    durations=durations[None],
                )[0][0]
            if self.energy_normalize is not None:
                energy = self.energy_normalize(energy[None])[0][0]
            if energy is not None:
                kwargs["energy"] = energy

        if spembs is not None:
            kwargs["spembs"] = spembs

        outs, probs, att_ws, ref_embs, ar_prior_loss = self.tts.inference(
            text=text,
            **kwargs,
            **decode_config,
        )

        if self.normalize is not None:
            # NOTE: normalize.inverse is in-place operation
            outs_denorm = self.normalize.inverse(outs.clone()[None])[0][0]
        else:
            outs_denorm = outs
        return outs, outs_denorm, probs, att_ws, ref_embs, ar_prior_loss
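

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original ESPnet file): it wires a
    # dummy AbsTTS implementation into ESPnetTTSModel to show the calling
    # convention of forward() and collect_feats(). ``DummyTTS`` and the random
    # shapes below are illustrative assumptions, not a real recipe.
    class DummyTTS(AbsTTS):
        """Stand-in for a real TTS module such as Tacotron2 or FastSpeech2."""

        def forward(self, text, text_lengths, speech, speech_lengths, **kwargs):
            # Return the (loss, stats, weight) triple expected by the wrapper.
            loss = speech.mean() * 0.0
            stats = {"loss": loss.detach()}
            weight = speech.new_tensor(float(speech.size(0)))
            return loss, stats, weight

        def inference(self, text, **kwargs):
            raise NotImplementedError

    model = ESPnetTTSModel(
        feats_extract=None,  # features are assumed to be precomputed here
        pitch_extract=None,
        energy_extract=None,
        normalize=None,
        pitch_normalize=None,
        energy_normalize=None,
        tts=DummyTTS(),
    )

    text = torch.randint(0, 50, (2, 10))
    text_lengths = torch.tensor([10, 8])
    feats = torch.randn(2, 120, 80)  # (batch, frames, mel bins)
    feats_lengths = torch.tensor([120, 100])

    loss, stats, weight = model(text, text_lengths, feats, feats_lengths)
    print(stats)
    print(model.collect_feats(text, text_lengths, feats, feats_lengths).keys())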