"""Style encoder of GST-Tacotron.""" |
|
|
|
from typing import Sequence

import torch

from typeguard import check_argument_types
|
|
|
from espnet.nets.pytorch_backend.transformer.attention import ( |
|
MultiHeadedAttention as BaseMultiHeadedAttention, |
|
) |
|
|
|
|
|
class StyleEncoder(torch.nn.Module): |
|
"""Style encoder. |
|
|
|
    This module is the style encoder introduced in `Style Tokens: Unsupervised Style
|
Modeling, Control and Transfer in End-to-End Speech Synthesis`. |
|
|
|
.. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End |
|
Speech Synthesis`: https://arxiv.org/abs/1803.09017 |
|
|
|
Args: |
|
idim (int, optional): Dimension of the input mel-spectrogram. |
|
gst_tokens (int, optional): The number of GST embeddings. |
|
gst_token_dim (int, optional): Dimension of each GST embedding. |
|
gst_heads (int, optional): The number of heads in GST multihead attention. |
|
conv_layers (int, optional): The number of conv layers in the reference encoder. |
|
        conv_chans_list (Sequence[int], optional):

            List of the number of channels of conv layers in the reference encoder.
|
conv_kernel_size (int, optional): |
|
            Kernel size of conv layers in the reference encoder.
|
conv_stride (int, optional): |
|
Stride size of conv layers in the reference encoder. |
|
gru_layers (int, optional): The number of GRU layers in the reference encoder. |
|
gru_units (int, optional): The number of GRU units in the reference encoder. |
|
|
|
Todo: |
|
* Support manual weight specification in inference. |
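
    Example:
        A minimal usage sketch; the output dimension follows the default
        ``gst_token_dim`` (256):

        >>> import torch
        >>> style_encoder = StyleEncoder(idim=80)
        >>> mel = torch.randn(2, 100, 80)  # (B, Lmax, idim) reference mel-spectrogram
        >>> style_embs = style_encoder(mel)
        >>> style_embs.shape
        torch.Size([2, 256])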
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
idim: int = 80, |
|
gst_tokens: int = 10, |
|
gst_token_dim: int = 256, |
|
gst_heads: int = 4, |
|
conv_layers: int = 6, |
|
conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), |
|
conv_kernel_size: int = 3, |
|
conv_stride: int = 2, |
|
gru_layers: int = 1, |
|
gru_units: int = 128, |
|
): |
|
"""Initilize global style encoder module.""" |
|
assert check_argument_types() |
|
super(StyleEncoder, self).__init__() |
|
|
|
self.ref_enc = ReferenceEncoder( |
|
idim=idim, |
|
conv_layers=conv_layers, |
|
conv_chans_list=conv_chans_list, |
|
conv_kernel_size=conv_kernel_size, |
|
conv_stride=conv_stride, |
|
gru_layers=gru_layers, |
|
gru_units=gru_units, |
|
) |
|
self.stl = StyleTokenLayer( |
|
ref_embed_dim=gru_units, |
|
gst_tokens=gst_tokens, |
|
gst_token_dim=gst_token_dim, |
|
gst_heads=gst_heads, |
|
) |
|
|
|
def forward(self, speech: torch.Tensor) -> torch.Tensor: |
|
"""Calculate forward propagation. |
|
|
|
Args: |
|
            speech (Tensor): Batch of padded target features (B, Lmax, idim).
|
|
|
Returns: |
|
            Tensor: Style token embeddings (B, gst_token_dim).
|
|
|
""" |
|
        ref_embs = self.ref_enc(speech)  # (B, gru_units)

        style_embs = self.stl(ref_embs)  # (B, gst_token_dim)
|
|
|
return style_embs |
|
|
|
|
|
class ReferenceEncoder(torch.nn.Module): |
|
"""Reference encoder module. |
|
|
|
    This module is the reference encoder introduced in `Style Tokens: Unsupervised Style
|
Modeling, Control and Transfer in End-to-End Speech Synthesis`. |
|
|
|
.. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End |
|
Speech Synthesis`: https://arxiv.org/abs/1803.09017 |
|
|
|
Args: |
|
idim (int, optional): Dimension of the input mel-spectrogram. |
|
conv_layers (int, optional): The number of conv layers in the reference encoder. |
|
        conv_chans_list (Sequence[int], optional):

            List of the number of channels of conv layers in the reference encoder.
|
conv_kernel_size (int, optional): |
|
            Kernel size of conv layers in the reference encoder.
|
conv_stride (int, optional): |
|
Stride size of conv layers in the reference encoder. |
|
gru_layers (int, optional): The number of GRU layers in the reference encoder. |
|
gru_units (int, optional): The number of GRU units in the reference encoder. |
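
    Example:
        A minimal usage sketch; the output dimension equals ``gru_units``:

        >>> import torch
        >>> ref_enc = ReferenceEncoder(idim=80)
        >>> mel = torch.randn(2, 100, 80)  # (B, Lmax, idim)
        >>> ref_embs = ref_enc(mel)
        >>> ref_embs.shape
        torch.Size([2, 128])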
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
        idim: int = 80,
|
conv_layers: int = 6, |
|
conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), |
|
conv_kernel_size: int = 3, |
|
conv_stride: int = 2, |
|
gru_layers: int = 1, |
|
gru_units: int = 128, |
|
): |
|
"""Initilize reference encoder module.""" |
|
assert check_argument_types() |
|
super(ReferenceEncoder, self).__init__() |
|
|
|
|
|
assert conv_kernel_size % 2 == 1, "kernel size must be odd." |
|
assert ( |
|
len(conv_chans_list) == conv_layers |
|
), "the number of conv layers and length of channels list must be the same." |
|
|
|
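        # Build the 2D conv stack of the reference encoder: each layer
        # downsamples the time and frequency axes by ``conv_stride`` and is
        # followed by batch normalization and ReLU.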
convs = [] |
|
padding = (conv_kernel_size - 1) // 2 |
|
for i in range(conv_layers): |
|
conv_in_chans = 1 if i == 0 else conv_chans_list[i - 1] |
|
conv_out_chans = conv_chans_list[i] |
|
convs += [ |
|
torch.nn.Conv2d( |
|
conv_in_chans, |
|
conv_out_chans, |
|
kernel_size=conv_kernel_size, |
|
stride=conv_stride, |
|
padding=padding, |
|
|
|
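                    # no bias, since it is redundant with the following batch norm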
bias=False, |
|
), |
|
torch.nn.BatchNorm2d(conv_out_chans), |
|
torch.nn.ReLU(inplace=True), |
|
] |
|
self.convs = torch.nn.Sequential(*convs) |
|
|
|
self.conv_layers = conv_layers |
|
self.kernel_size = conv_kernel_size |
|
self.stride = conv_stride |
|
self.padding = padding |
|
|
|
|
|
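        # Compute the GRU input size: each conv layer reduces the frequency
        # axis to (W - K + 2P) // S + 1, and the final width is multiplied by
        # the number of output channels of the last conv layer.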
gru_in_units = idim |
|
for i in range(conv_layers): |
|
gru_in_units = ( |
|
gru_in_units - conv_kernel_size + 2 * padding |
|
) // conv_stride + 1 |
|
        gru_in_units *= conv_chans_list[-1]
|
self.gru = torch.nn.GRU(gru_in_units, gru_units, gru_layers, batch_first=True) |
|
|
|
def forward(self, speech: torch.Tensor) -> torch.Tensor: |
|
"""Calculate forward propagation. |
|
|
|
Args: |
|
speech (Tensor): Batch of padded target features (B, Lmax, idim). |
|
|
|
Returns: |
|
Tensor: Reference embedding (B, gru_units) |
|
|
|
""" |
|
batch_size = speech.size(0) |
|
        xs = speech.unsqueeze(1)  # (B, 1, Lmax, idim)
|
        hs = self.convs(xs).transpose(1, 2)  # (B, Lmax', conv_out_chans, idim')
|
|
|
time_length = hs.size(1) |
|
        # flatten the channel and frequency axes: (B, Lmax', conv_out_chans * idim')
        hs = hs.contiguous().view(batch_size, time_length, -1)
|
self.gru.flatten_parameters() |
|
        _, ref_embs = self.gru(hs)  # ref_embs: (gru_layers, B, gru_units)

        ref_embs = ref_embs[-1]  # hidden state of the last GRU layer (B, gru_units)
|
|
|
return ref_embs |
|
|
|
|
|
class StyleTokenLayer(torch.nn.Module): |
|
"""Style token layer module. |
|
|
|
    This module is the style token layer introduced in `Style Tokens: Unsupervised Style
|
Modeling, Control and Transfer in End-to-End Speech Synthesis`. |
|
|
|
.. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End |
|
Speech Synthesis`: https://arxiv.org/abs/1803.09017 |
|
|
|
Args: |
|
ref_embed_dim (int, optional): Dimension of the input reference embedding. |
|
gst_tokens (int, optional): The number of GST embeddings. |
|
gst_token_dim (int, optional): Dimension of each GST embedding. |
|
gst_heads (int, optional): The number of heads in GST multihead attention. |
|
dropout_rate (float, optional): Dropout rate in multi-head attention. |
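
    Example:
        A minimal usage sketch; the reference embedding attends over the
        learned style tokens and returns a ``gst_token_dim``-dimensional
        style embedding:

        >>> import torch
        >>> stl = StyleTokenLayer(ref_embed_dim=128, gst_token_dim=256)
        >>> ref_embs = torch.randn(2, 128)  # (B, ref_embed_dim)
        >>> style_embs = stl(ref_embs)
        >>> style_embs.shape
        torch.Size([2, 256])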
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
ref_embed_dim: int = 128, |
|
gst_tokens: int = 10, |
|
gst_token_dim: int = 256, |
|
gst_heads: int = 4, |
|
dropout_rate: float = 0.0, |
|
): |
|
"""Initilize style token layer module.""" |
|
assert check_argument_types() |
|
super(StyleTokenLayer, self).__init__() |
|
|
|
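        # Learnable style token embeddings; following the GST paper, each token
        # has dimension gst_token_dim // gst_heads.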
gst_embs = torch.randn(gst_tokens, gst_token_dim // gst_heads) |
|
self.register_parameter("gst_embs", torch.nn.Parameter(gst_embs)) |
|
self.mha = MultiHeadedAttention( |
|
q_dim=ref_embed_dim, |
|
k_dim=gst_token_dim // gst_heads, |
|
v_dim=gst_token_dim // gst_heads, |
|
n_head=gst_heads, |
|
n_feat=gst_token_dim, |
|
dropout_rate=dropout_rate, |
|
) |
|
|
|
def forward(self, ref_embs: torch.Tensor) -> torch.Tensor: |
|
"""Calculate forward propagation. |
|
|
|
Args: |
|
ref_embs (Tensor): Reference embeddings (B, ref_embed_dim). |
|
|
|
Returns: |
|
Tensor: Style token embeddings (B, gst_token_dim). |
|
|
|
""" |
|
batch_size = ref_embs.size(0) |
|
|
|
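        # Squash the tokens with tanh and broadcast them over the batch,
        # yielding (B, gst_tokens, gst_token_dim // gst_heads).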
gst_embs = torch.tanh(self.gst_embs).unsqueeze(0).expand(batch_size, -1, -1) |
|
|
|
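        # The reference embedding acts as a single query (a length-1 sequence)
        # attending over the style tokens.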
ref_embs = ref_embs.unsqueeze(1) |
|
style_embs = self.mha(ref_embs, gst_embs, gst_embs, None) |
|
|
|
return style_embs.squeeze(1) |
|
|
|
|
|
class MultiHeadedAttention(BaseMultiHeadedAttention): |
|
"""Multi head attention module with different input dimension.""" |
|
|
|
def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0): |
|
"""Initialize multi head attention module.""" |
|
|
|
|
|
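        # NOTE: Call torch.nn.Module.__init__() directly instead of
        # super().__init__(), since BaseMultiHeadedAttention.__init__() builds
        # projection layers that assume the query, key, and value all have
        # dimension n_feat, whereas here q_dim, k_dim, and v_dim may differ.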
torch.nn.Module.__init__(self) |
|
assert n_feat % n_head == 0 |
|
|
|
self.d_k = n_feat // n_head |
|
self.h = n_head |
|
self.linear_q = torch.nn.Linear(q_dim, n_feat) |
|
self.linear_k = torch.nn.Linear(k_dim, n_feat) |
|
self.linear_v = torch.nn.Linear(v_dim, n_feat) |
|
self.linear_out = torch.nn.Linear(n_feat, n_feat) |
|
self.attn = None |
|
self.dropout = torch.nn.Dropout(p=dropout_rate) |
|
|