# Scraped page header (not part of the module): Spaces status "Running"; file size: 953 Bytes; revision 39f384e.
import torch
class LayerNorm(torch.nn.Module):
    """Layer normalization over the channel axis of a channels-first tensor.

    Normalization statistics are computed across ``channels`` (dim 1 of the
    input), after which a learnable per-channel scale (``gamma``) and shift
    (``beta``) are applied.

    Args:
        channels (int): Number of channels.
        eps (float, optional): Epsilon value for numerical stability.
            Defaults to 1e-5.
    """

    def __init__(self, channels: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        # gamma starts at one and beta at zero, so the affine step is the
        # identity until the parameters are trained.
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` across its channel axis.

        Args:
            x (torch.Tensor): Input tensor of shape
                (batch_size, channels, time_steps).

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        # layer_norm normalizes over the trailing dimension(s), so swap the
        # channel axis (dim 1) into the last position first:
        # (batch_size, channels, time_steps) -> (batch_size, time_steps, channels)
        x = x.transpose(1, -1)
        x = torch.nn.functional.layer_norm(
            x, (x.size(-1),), self.gamma, self.beta, self.eps
        )
        # Swap back to (batch_size, channels, time_steps).
        return x.transpose(1, -1)
|