| import torch | |
| from espnet2.enh.decoder.abs_decoder import AbsDecoder | |
| class ConvDecoder(AbsDecoder): | |
| """Transposed Convolutional decoder for speech enhancement and separation """ | |
| def __init__( | |
| self, | |
| channel: int, | |
| kernel_size: int, | |
| stride: int, | |
| ): | |
| super().__init__() | |
| self.convtrans1d = torch.nn.ConvTranspose1d( | |
| channel, 1, kernel_size, bias=False, stride=stride | |
| ) | |
| def forward(self, input: torch.Tensor, ilens: torch.Tensor): | |
| """Forward. | |
| Args: | |
| input (torch.Tensor): spectrum [Batch, T, F] | |
| ilens (torch.Tensor): input lengths [Batch] | |
| """ | |
| input = input.transpose(1, 2) | |
| batch_size = input.shape[0] | |
| wav = self.convtrans1d(input, output_size=(batch_size, 1, ilens.max())) | |
| wav = wav.squeeze(1) | |
| return wav, ilens | |