import torch
import torch.nn as nn

from modules.ChatTTS.ChatTTS.model.dvae import ConvNeXtBlock, DVAEDecoder

from .wavenet import WaveNet

def get_encoder_config(decoder: DVAEDecoder) -> dict[str, int | bool]:
    """Read the hyperparameters off a (pretrained) DVAEDecoder so a
    DVAEEncoder can be built as its structural mirror."""
    return {
        "idim": decoder.conv_out.out_channels,
        "odim": decoder.conv_in[0].in_channels,
        "n_layer": len(decoder.decoder_block),
        "bn_dim": decoder.conv_in[0].out_channels,
        "hidden": decoder.conv_in[2].out_channels,
        "kernel": decoder.decoder_block[0].dwconv.kernel_size[0],
        "dilation": decoder.decoder_block[0].dwconv.dilation[0],
        "down": decoder.up,
    }
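
# A minimal usage sketch (assumption: `dvae` stands for a loaded ChatTTS DVAE
# module exposing its pretrained `decoder` submodule; the variable name is
# illustrative, not part of this repo's API):
#
#     encoder = DVAEEncoder(**get_encoder_config(dvae.decoder))
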

class DVAEEncoder(nn.Module):
    def __init__(
        self,
        idim: int,
        odim: int,
        n_layer: int = 12,
        bn_dim: int = 64,
        hidden: int = 256,
        kernel: int = 7,
        dilation: int = 2,
        down: bool = False,  # mirrors the decoder's `up` flag; currently unused
    ) -> None:
        super().__init__()
        # WaveNet front-end: lifts 100-bin mel frames to idim channels.
        self.wavenet = WaveNet(
            input_channels=100,
            residual_channels=idim,
            residual_layers=20,
            dilation_cycle=4,
        )
        # kernel_size=1, stride=1: equivalent to a pointwise Conv1d projection.
        self.conv_in_transpose = nn.ConvTranspose1d(
            idim, hidden, kernel_size=1, bias=False
        )
        # Alternative input projection, kept commented out for reference:
        # nn.Sequential(
        #     nn.ConvTranspose1d(100, idim, 3, 1, 1, bias=False),
        #     nn.ConvTranspose1d(idim, hidden, kernel_size=1, bias=False),
        # )
        self.encoder_block = nn.ModuleList(
            [
                ConvNeXtBlock(
                    hidden,
                    hidden * 4,
                    kernel,
                    dilation,
                )
                for _ in range(n_layer)
            ]
        )
        # Output head hidden -> bn_dim -> odim, reversing the decoder's conv_in stack.
        self.conv_out_transpose = nn.Sequential(
            nn.Conv1d(hidden, bn_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(bn_dim, odim, 3, 1, 1),
        )

    def forward(
        self,
        audio_mel_specs: torch.Tensor,  # (batch_size, audio_len*2, 100)
        audio_attention_mask: torch.Tensor,  # (batch_size, audio_len)
        conditioning: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Upsample the token-rate mask to the mel frame rate (two mel frames
        # per audio token): (batch_size, audio_len) -> (batch_size, audio_len*2).
        mel_attention_mask = (
            audio_attention_mask.unsqueeze(-1).repeat(1, 1, 2).flatten(1)
        )
        x: torch.Tensor = self.wavenet(
            audio_mel_specs.transpose(1, 2)
        )  # (batch_size, idim, audio_len*2)
        x = x * mel_attention_mask.unsqueeze(1)  # zero out padded frames
        x = self.conv_in_transpose(x)  # (batch_size, hidden, audio_len*2)
        for f in self.encoder_block:
            x = f(x, conditioning)
        x = self.conv_out_transpose(x)  # (batch_size, odim, audio_len*2)
        # Fold the frame axis back to the token rate:
        # (batch_size, odim, audio_len*2) -> (batch_size, audio_len, odim*2).
        x = (
            x.view(x.size(0), x.size(1), 2, x.size(2) // 2)
            .permute(0, 3, 1, 2)
            .flatten(2)
        )
        return x  # (batch_size, audio_len, audio_dim=odim*2)
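

# Shape smoke test: a sketch, not part of the training pipeline. The sizes
# below are illustrative; it only assumes the sibling WaveNet maps
# (batch, 100, frames) -> (batch, idim, frames), as forward() already relies on.
if __name__ == "__main__":
    encoder = DVAEEncoder(idim=100, odim=512, n_layer=2)
    batch, audio_len = 2, 8
    mel = torch.randn(batch, audio_len * 2, 100)  # (batch, audio_len*2, 100)
    mask = torch.ones(batch, audio_len)  # all positions valid
    out = encoder(mel, mask)
    assert out.shape == (batch, audio_len, 512 * 2)
    print(out.shape)  # torch.Size([2, 8, 1024])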