#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn


class CLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization."""

    def __init__(self, *args, **kwargs):
        super(CLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensors as input'.format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # nn.LayerNorm normalizes over the last (channel) dimension.
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class ILayerNorm(nn.InstanceNorm1d):
    """Channel-wise layer normalization via nn.InstanceNorm1d.

    After the transpose below, InstanceNorm1d normalizes over the channel
    dimension independently for each time step.
    """

    def __init__(self, *args, **kwargs):
        super(ILayerNorm, self).__init__(*args, **kwargs)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensors as input'.format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # InstanceNorm1d normalizes over the last dimension (channels here).
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class GLayerNorm(nn.Module):
    """Global Layer Normalization for TasNet."""

    def __init__(self, channels, eps=1e-5):
        super(GLayerNorm, self).__init__()
        self.eps = eps
        self.norm_dim = channels
        self.gamma = nn.Parameter(torch.Tensor(channels))
        self.beta = nn.Parameter(torch.Tensor(channels))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensors as input'.format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # Mean and variance over both time and channel dims: [N, 1, 1]
        mean = torch.mean(sample, (1, 2), keepdim=True)
        var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
        sample = (sample - mean) / torch.sqrt(var + self.eps) * \
            self.gamma + self.beta
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class _LayerNorm(nn.Module):
    """Layer Normalization base class."""

    def __init__(self, channel_size):
        super(_LayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size),
                                  requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size),
                                 requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Assumes input of size `[batch, channel, *]`."""
        return (self.gamma * normed_x.transpose(1, -1) +
                self.beta).transpose(1, -1)


class GlobLayerNorm(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())
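

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file):
# runs each normalization variant on a random [N, C, T] tensor and checks
# that the output shape is preserved. The sizes below are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch, channels, length = 2, 64, 100
    x = torch.randn(batch, channels, length)

    # CLayerNorm wraps nn.LayerNorm, so it takes the channel count: after
    # the internal transpose, channels sit on the last dimension.
    cln = CLayerNorm(channels)
    # ILayerNorm wraps nn.InstanceNorm1d; after the transpose it sees the
    # T time frames as "channels", so num_features must equal the length.
    iln = ILayerNorm(length)
    # Both global variants normalize over channels and time jointly.
    gln = GLayerNorm(channels)
    glob_ln = GlobLayerNorm(channels)

    for name, norm in [('CLayerNorm', cln),
                       ('ILayerNorm', iln),
                       ('GLayerNorm', gln),
                       ('GlobLayerNorm', glob_ln)]:
        y = norm(x)
        assert y.shape == x.shape
        print('{}: output shape {}'.format(name, tuple(y.shape)))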