Spaces:
Sleeping
Sleeping
File size: 1,616 Bytes
0883aa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
import torch.nn as nn

from modules.general.utils import Conv1d
class GaU(nn.Module):
    r"""Gated Activation Unit (GaU) as introduced in `Conditional Image
    Generation with PixelCNN Decoders <https://arxiv.org/abs/1606.05328>`_.

    A single dilated convolution produces ``2 * channels`` feature maps, which
    are split along the channel axis into a "filter" half and a "gate" half.
    When a context projection is configured, the 1x1-projected context is added
    to both halves before gating. The output is
    ``tanh(filter) * sigmoid(gate)``.

    Args:
        channels: number of input channels (also the number of output
            channels).
        kernel_size: kernel size of the dilated convolution. Padding is
            ``dilation * (kernel_size - 1) // 2``, so an odd kernel size keeps
            the time dimension unchanged (assuming ``Conv1d`` follows
            ``nn.Conv1d`` semantics with stride 1).
        dilation: dilation rate of the convolution.
        d_context: number of channels of the context tensor. Use ``None``
            (or any falsy value) to build the unit without a context
            projection.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: int = 1,
        d_context: Optional[int] = None,
    ):
        super().__init__()
        # Stored under its original attribute name; a falsy value disables context.
        self.context = d_context
        self.conv = Conv1d(
            channels,
            channels * 2,
            kernel_size,
            dilation=dilation,
            # Symmetric "same" padding for odd kernel sizes.
            padding=dilation * (kernel_size - 1) // 2,
        )
        if self.context:
            # 1x1 projection so the context can be added to both halves.
            self.context_proj = Conv1d(d_context, channels * 2, 1)

    def forward(
        self, x: torch.Tensor, context: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""Calculate forward propagation.

        Args:
            x: input tensor with shape [B, C, T].
            context: context tensor with shape [B, ``d_context``, T]. It is
                required when the unit was built with ``d_context``. It is
                ignored when the unit has no context projection.

        Returns:
            Gated tensor with ``channels`` channels, i.e. [B, C, T] for an odd
            ``kernel_size``.

        Raises:
            ValueError: if the unit was built with ``d_context`` but
                ``context`` is None.
        """
        h = self.conv(x)
        if self.context:
            if context is None:
                # Fail loudly instead of letting the projection choke on None.
                raise ValueError(
                    f"GaU was constructed with d_context={self.context} "
                    "but forward() received context=None."
                )
            h = h + self.context_proj(context)
        # First half drives the tanh "filter", second half the sigmoid "gate".
        h_filter, h_gate = h.chunk(2, dim=1)
        return torch.tanh(h_filter) * torch.sigmoid(h_gate)
|