#!/usr/bin/env python -u
# -*- coding: utf-8 -*-
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
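"""Normalization layers (channel-wise, instance-based, and global variants)
used in time-domain speech separation models such as TasNet."""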
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn


class CLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization: LayerNorm over the channel dim at each time step."""
def __init__(self, *args, **kwargs):
super(CLayerNorm, self).__init__(*args, **kwargs)
def forward(self, sample):
"""Forward function.
Args:
sample: [batch_size, channels, length]
"""
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensors as input'.format(
                self.__class__.__name__))
# [N, C, T] -> [N, T, C]
sample = torch.transpose(sample, 1, 2)
# LayerNorm
sample = super().forward(sample)
# [N, T, C] -> [N, C, T]
sample = torch.transpose(sample, 1, 2)
return sample


class ILayerNorm(nn.InstanceNorm1d):
    """Channel-wise normalization implemented with ``nn.InstanceNorm1d``."""
def __init__(self, *args, **kwargs):
super(ILayerNorm, self).__init__(*args, **kwargs)
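    # NOTE: ``forward`` transposes the input to [N, T, C] before calling the
    # parent, so ``num_features`` passed at construction should match the
    # sequence length T (the post-transpose channel dim), not the channel count.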
def forward(self, sample):
"""Forward function.
Args:
sample: [batch_size, channels, length]
"""
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensors as input'.format(
                self.__class__.__name__))
# [N, C, T] -> [N, T, C]
sample = torch.transpose(sample, 1, 2)
        # InstanceNorm1d (stats over the channel dim at each time step)
sample = super().forward(sample)
# [N, T, C] -> [N, C, T]
sample = torch.transpose(sample, 1, 2)
return sample


class GLayerNorm(nn.Module):
    """Global Layer Normalization for TasNet."""
def __init__(self, channels, eps=1e-5):
super(GLayerNorm, self).__init__()
self.eps = eps
self.norm_dim = channels
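        # Learnable per-channel scale and shift, each of shape [channels],
        # initialized to ones/zeros in reset_parameters().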
self.gamma = nn.Parameter(torch.Tensor(channels))
self.beta = nn.Parameter(torch.Tensor(channels))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.gamma)
nn.init.zeros_(self.beta)
def forward(self, sample):
"""Forward function.
Args:
sample: [batch_size, channels, length]
"""
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensors as input'.format(
                self.__class__.__name__))
# [N, C, T] -> [N, T, C]
sample = torch.transpose(sample, 1, 2)
# Mean and variance [N, 1, 1]
mean = torch.mean(sample, (1, 2), keepdim=True)
var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
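        # gLN: normalize over time and channel dims jointly, then apply the
        # per-channel affine transform (gamma broadcasts over the last dim).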
sample = (sample - mean) / torch.sqrt(var + self.eps) * \
self.gamma + self.beta
# [N, T, C] -> [N, C, T]
sample = torch.transpose(sample, 1, 2)
return sample


class _LayerNorm(nn.Module):
    """Layer Normalization base class."""
def __init__(self, channel_size):
super(_LayerNorm, self).__init__()
self.channel_size = channel_size
self.gamma = nn.Parameter(torch.ones(channel_size),
requires_grad=True)
self.beta = nn.Parameter(torch.zeros(channel_size),
requires_grad=True)
def apply_gain_and_bias(self, normed_x):
""" Assumes input of size `[batch, chanel, *]`. """
return (self.gamma * normed_x.transpose(1, -1) +
self.beta).transpose(1, -1)


class GlobLayerNorm(_LayerNorm):
    """Global Layer Normalization (globLN)."""
def forward(self, x):
""" Applies forward pass.
Works for any input size > 2D.
Args:
x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
Returns:
:class:`torch.Tensor`: gLN_x `[batch, chan, *]`
"""
dims = list(range(1, len(x.shape)))
mean = x.mean(dim=dims, keepdim=True)
var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())
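

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original training pipeline):
# builds a dummy [batch, channels, length] tensor and checks that every
# normalization layer preserves the input shape. The sizes below are
# arbitrary illustrative choices.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch, channels, length = 4, 8, 100
    x = torch.randn(batch, channels, length)
    # These three normalize over the channel dim, so they take `channels`.
    for norm in (CLayerNorm(channels), GLayerNorm(channels),
                 GlobLayerNorm(channels)):
        assert norm(x).shape == x.shape, type(norm).__name__
    # ILayerNorm transposes to [N, T, C] before calling InstanceNorm1d, so its
    # `num_features` should match the sequence length instead.
    assert ILayerNorm(length)(x).shape == x.shape
    print('all normalization layers preserve the [N, C, T] shape')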