Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python -u | |
# -*- coding: utf-8 -*- | |
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang) | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import torch | |
import torch.nn as nn | |
class CLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization for ``[batch, channels, length]`` input.

    ``nn.LayerNorm`` normalizes over the trailing ``normalized_shape``
    dimensions, so the input is transposed to ``[N, T, C]``, normalized, and
    transposed back. Constructor arguments are forwarded unchanged to
    ``nn.LayerNorm``. For channel-wise behaviour ``normalized_shape`` is
    expected to equal the channel count (confirm at call sites).
    """

    def __init__(self, *args, **kwargs):
        super(CLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: tensor of shape ``[batch_size, channels, length]``.

        Returns:
            Tensor of the same shape as ``sample``, normalized per frame.

        Raises:
            RuntimeError: if ``sample`` is not a 3-D tensor.
        """
        if sample.dim() != 3:
            # nn.Module instances have no ``__name__`` attribute; reading
            # ``self.__name__`` raised AttributeError and masked this error.
            raise RuntimeError('{} only accept 3-D tensor as input'.format(
                type(self).__name__))
        # [N, C, T] -> [N, T, C] so the channel axis is the trailing one.
        sample = torch.transpose(sample, 1, 2)
        # LayerNorm over the trailing (channel) dimension.
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample
class ILayerNorm(nn.InstanceNorm1d):
    """Channel-wise normalization built on ``nn.InstanceNorm1d``.

    The ``[N, C, T]`` input is transposed to ``[N, T, C]`` before calling
    ``InstanceNorm1d``. That layer treats dim 1 as its feature axis and
    normalizes over the last axis. As a result, each ``(batch, frame)`` pair
    is normalized across the ``C`` channels.

    NOTE(review): because of the transpose, ``InstanceNorm1d`` compares
    ``num_features`` (and sizes any affine parameters) against ``T``, not
    ``C``. Confirm this is intended at call sites.
    """

    def __init__(self, *args, **kwargs):
        super(ILayerNorm, self).__init__(*args, **kwargs)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: tensor of shape ``[batch_size, channels, length]``.

        Returns:
            Tensor of the same shape as ``sample``.

        Raises:
            RuntimeError: if ``sample`` is not a 3-D tensor.
        """
        if sample.dim() != 3:
            # nn.Module instances have no ``__name__`` attribute; reading
            # ``self.__name__`` raised AttributeError and masked this error.
            raise RuntimeError('{} only accept 3-D tensor as input'.format(
                type(self).__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # InstanceNorm1d normalizes over the last dim (C) for every (n, t).
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample
class GLayerNorm(nn.Module):
    """Global Layer Normalization for TasNet.

    Each sample is normalized with a single mean and variance computed over
    all of its channels and frames. A per-channel affine transform
    (``gamma``, ``beta``) is then applied.

    Args:
        channels: number of channels ``C``; sizes ``gamma`` and ``beta``.
        eps: value added to the variance before the square root.
    """

    def __init__(self, channels, eps=1e-5):
        super(GLayerNorm, self).__init__()
        self.eps = eps
        self.norm_dim = channels
        # Allocated uninitialized; reset_parameters() fills them in.
        self.gamma = nn.Parameter(torch.empty(channels))
        self.beta = nn.Parameter(torch.empty(channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Set ``gamma`` to ones and ``beta`` to zeros (identity affine)."""
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: tensor of shape ``[batch_size, channels, length]``.

        Returns:
            Tensor of the same shape as ``sample``.

        Raises:
            RuntimeError: if ``sample`` is not a 3-D tensor.
        """
        if sample.dim() != 3:
            # nn.Module instances have no ``__name__`` attribute; reading
            # ``self.__name__`` raised AttributeError and masked this error.
            raise RuntimeError('{} only accept 3-D tensor as input'.format(
                type(self).__name__))
        # [N, C, T] -> [N, T, C] so gamma/beta broadcast along the last dim.
        sample = torch.transpose(sample, 1, 2)
        # Mean and (biased) variance over channels and frames: [N, 1, 1]
        mean = torch.mean(sample, (1, 2), keepdim=True)
        var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
        sample = (sample - mean) / torch.sqrt(var + self.eps) * \
            self.gamma + self.beta
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample
class _LayerNorm(nn.Module): | |
"""Layer Normalization base class.""" | |
def __init__(self, channel_size): | |
super(_LayerNorm, self).__init__() | |
self.channel_size = channel_size | |
self.gamma = nn.Parameter(torch.ones(channel_size), | |
requires_grad=True) | |
self.beta = nn.Parameter(torch.zeros(channel_size), | |
requires_grad=True) | |
def apply_gain_and_bias(self, normed_x): | |
""" Assumes input of size `[batch, chanel, *]`. """ | |
return (self.gamma * normed_x.transpose(1, -1) + | |
self.beta).transpose(1, -1) | |
class GlobLayerNorm(_LayerNorm):
    """Global Layer Normalization (globLN).

    Each sample is normalized with one mean/variance computed over every
    non-batch dimension. The per-channel ``gamma``/``beta`` from the base
    class are then applied.

    Args:
        channel_size: number of channels ``C`` (dim 1 of the input).
        eps: value added to the variance before the square root. Defaults to
            ``1e-8``, the value previously hard-coded in ``forward``.
    """

    # Class-level fallback so instances pickled before ``eps`` existed still
    # find a value when their ``__dict__`` is restored without it.
    eps = 1e-8

    def __init__(self, channel_size, eps=1e-8):
        super(GlobLayerNorm, self).__init__(channel_size)
        self.eps = eps

    def forward(self, x):
        """ Applies forward pass.

        Works for any input with at least 2 dimensions.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        # Reduce over every dimension except batch.
        dims = list(range(1, x.dim()))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + self.eps).sqrt())