import torch
import torch.nn as nn

from modules.ChatTTS.ChatTTS.model.dvae import ConvNeXtBlock, DVAEDecoder

from .wavenet import WaveNet


def get_encoder_config(decoder: DVAEDecoder) -> dict[str, int | bool]:
    """Read a decoder's layer sizes off as keyword arguments for ``DVAEEncoder``.

    The keys of the returned dict match the parameter names of
    ``DVAEEncoder.__init__``. Channel widths are taken from the opposite end
    of the decoder: the encoder's ``idim`` is the decoder's output width and
    the encoder's ``odim`` is the decoder's input width.

    Args:
        decoder: Decoder whose ``conv_in``, ``conv_out``, ``decoder_block``
            and ``up`` attributes are inspected.

    Returns:
        Mapping with the keys ``idim``, ``odim``, ``n_layer``, ``bn_dim``,
        ``hidden``, ``kernel``, ``dilation`` and ``down``.
    """
    stem = decoder.conv_in
    # Every block is assumed to share the first block's depthwise-conv
    # geometry; only block 0 is inspected.
    depthwise = decoder.decoder_block[0].dwconv
    return {
        "idim": decoder.conv_out.out_channels,
        "odim": stem[0].in_channels,
        "n_layer": len(decoder.decoder_block),
        "bn_dim": stem[0].out_channels,
        "hidden": stem[2].out_channels,
        "kernel": depthwise.kernel_size[0],
        "dilation": depthwise.dilation[0],
        "down": decoder.up,
    }


class DVAEEncoder(nn.Module):
    """Encode mel frames into ``odim * 2`` features per audio step.

    Stages, in order: a WaveNet front end, a 1x1 transposed-convolution
    projection to ``hidden`` channels, ``n_layer`` ``ConvNeXtBlock`` layers,
    and a two-convolution output head. The result is then folded so that two
    time positions share one output step.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        n_layer: int = 12,
        bn_dim: int = 64,
        hidden: int = 256,
        kernel: int = 7,
        dilation: int = 2,
        down: bool = False,
    ) -> None:
        """Build the encoder stack.

        Args:
            idim: Channel width produced by the WaveNet front end.
            odim: Channel width produced by the output head (before folding).
            n_layer: Number of ``ConvNeXtBlock`` layers.
            bn_dim: Bottleneck width inside the output head.
            hidden: Channel width used throughout the ``ConvNeXtBlock`` stack.
            kernel: Kernel size passed to each ``ConvNeXtBlock``.
            dilation: Dilation passed to each ``ConvNeXtBlock``.
            down: Accepted for symmetry with ``get_encoder_config``; this
                constructor never reads it.
        """
        super().__init__()

        # Front end: the input width is fixed at 100 channels here.
        self.wavenet = WaveNet(
            input_channels=100,
            residual_channels=idim,
            residual_layers=20,
            dilation_cycle=4,
        )

        # NOTE(review): a commented-out alternative used to sit here,
        #   nn.Sequential(
        #       nn.ConvTranspose1d(100, idim, 3, 1, 1, bias=False),
        #       nn.ConvTranspose1d(idim, hidden, kernel_size=1, bias=False),
        #   )
        # It looks like the WaveNet above took over the 100 -> idim step;
        # confirm before dropping this note.
        self.conv_in_transpose = nn.ConvTranspose1d(
            idim, hidden, kernel_size=1, bias=False
        )

        intermediate = hidden * 4
        blocks = []
        for _ in range(n_layer):
            blocks.append(ConvNeXtBlock(hidden, intermediate, kernel, dilation))
        self.encoder_block = nn.ModuleList(blocks)

        self.conv_out_transpose = nn.Sequential(
            nn.Conv1d(hidden, bn_dim, 3, 1, 1),
            nn.GELU(),
            nn.Conv1d(bn_dim, odim, 3, 1, 1),
        )

    def forward(
        self,
        audio_mel_specs: torch.Tensor,  # (batch_size, audio_len*2, 100)
        audio_attention_mask: torch.Tensor,  # (batch_size, audio_len)
        conditioning=None,
    ) -> torch.Tensor:
        """Run the encoder on a batch of mel spectrograms.

        Args:
            audio_mel_specs: Mel frames, ``(batch_size, audio_len*2, 100)``.
            audio_attention_mask: Per-step mask, ``(batch_size, audio_len)``.
                It multiplies the WaveNet output, so zero entries blank the
                corresponding frames.
            conditioning: Forwarded unchanged as the second argument of every
                ``ConvNeXtBlock``.

        Returns:
            Tensor of shape ``(batch_size, audio_len, odim*2)``.
        """
        # Stretch the mask to mel resolution: every entry is repeated for two
        # consecutive frames (m0, m0, m1, m1, ...).
        frame_mask = audio_attention_mask.unsqueeze(-1).repeat(1, 1, 2)
        frame_mask = frame_mask.flatten(1)

        feats: torch.Tensor = self.wavenet(
            audio_mel_specs.transpose(1, 2)
        )  # (batch_size, idim, audio_len*2)
        feats = feats * frame_mask.unsqueeze(1)

        feats = self.conv_in_transpose(feats)  # (batch_size, hidden, audio_len*2)
        for block in self.encoder_block:
            feats = block(feats, conditioning)
        feats = self.conv_out_transpose(feats)  # (batch_size, odim, audio_len*2)

        # Fold the time axis in two and move both halves into the feature
        # axis: output[b, t, c*2 + k] = feats[b, c, k*audio_len + t].
        # NOTE(review): this pairs frame t with frame t + audio_len, whereas
        # the mask above pairs adjacent frames 2t and 2t+1 — confirm that this
        # layout is what the decoder side expects.
        batch, channels, frames = feats.shape
        folded = feats.view(batch, channels, 2, frames // 2)
        folded = folded.permute(0, 3, 1, 2)  # (batch_size, audio_len, odim, 2)
        return folded.flatten(2)  # (batch_size, audio_len, audio_dim=odim*2)