Spaces:
Runtime error
Runtime error
from typing import List, NoReturn, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_embedding(layer: nn.Module) -> None:
    r"""Initialize an embedding-style layer in place.

    The weight is drawn uniformly from [-1, 1]; a bias, if the layer has one,
    is set to zero.

    Args:
        layer: module with a ``weight`` parameter and an optional ``bias``
            (the attribute may be absent, or present but ``None``).
    """
    nn.init.uniform_(layer.weight, -1.0, 1.0)
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_layer(layer: nn.Module) -> None:
    r"""Initialize a Linear or Convolutional layer in place.

    The weight uses Xavier (Glorot) uniform initialization; a bias, if the
    layer has one, is set to zero.

    Args:
        layer: module with a ``weight`` parameter and an optional ``bias``
            (the attribute may be absent, or present but ``None``).
    """
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_bn(bn: nn.Module) -> None:
    r"""Initialize a Batchnorm layer in place.

    Sets the affine shift to 0 and scale to 1, and resets the running mean to
    0 and the running variance to 1. A layer built with ``affine=False`` has
    ``weight``/``bias`` set to ``None``, and one built with
    ``track_running_stats=False`` has ``None`` running buffers; those entries
    are skipped instead of raising ``AttributeError``.

    Args:
        bn: batch-norm module, e.g. ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``.
    """
    if bn.bias is not None:
        bn.bias.data.fill_(0.0)
    if bn.weight is not None:
        bn.weight.data.fill_(1.0)
    if bn.running_mean is not None:
        bn.running_mean.data.fill_(0.0)
    if bn.running_var is not None:
        bn.running_var.data.fill_(1.0)
def act(x: torch.Tensor, activation: str) -> torch.Tensor:
    r"""Apply the named activation function to ``x``.

    Args:
        x: input tensor. For ``"relu"`` and ``"leaky_relu"`` it is modified
            in place (``F.relu_`` / ``F.leaky_relu_``) and also returned.
        activation: one of ``"relu"``, ``"leaky_relu"`` (slope 0.01) or
            ``"swish"`` (``x * sigmoid(x)``, computed out of place).

    Returns:
        The activated tensor.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu_(x)
    elif activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.01)
    elif activation == "swish":
        return x * torch.sigmoid(x)
    else:
        raise ValueError(
            f"Incorrect activation! Got {activation!r}; "
            "expected 'relu', 'leaky_relu' or 'swish'."
        )
class Base:
    r"""Base class with magnitude / phase helpers built on ``self.stft``.

    ``Base`` does not define ``stft`` itself; the class inheriting from it must
    provide ``self.stft(input)`` returning a ``(real, imag)`` pair of tensors
    of identical shape. ``wav_to_spectrogram_phase`` additionally expects that
    shape to be 4-D: (n, 1, time_steps, freq_bins).
    """

    def __init__(self):
        r"""Base function for extracting spectrogram, cos, and sin, etc."""
        pass

    @staticmethod
    def _magnitude(
        real: torch.Tensor, imag: torch.Tensor, eps: float
    ) -> torch.Tensor:
        r"""Return ``sqrt(max(real ** 2 + imag ** 2, eps))`` elementwise."""
        # The lower bound is applied to the power, before the square root, so
        # a positive eps keeps every magnitude strictly positive.
        return torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5

    def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
        r"""Calculate the magnitude spectrogram of ``input``.

        Args:
            input: waveforms accepted by ``self.stft``, e.g.
                (batch_size, segment_samples)
            eps: lower bound applied to the power before the square root

        Returns:
            spectrogram: magnitude tensor with the shape of ``self.stft``'s
                outputs
        """
        real, imag = self.stft(input)
        return self._magnitude(real, imag, eps)

    def spectrogram_phase(
        self, input: torch.Tensor, eps: float = 0.0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Calculate the magnitude, cos, and sin of the STFT of ``input``.

        Args:
            input: waveforms accepted by ``self.stft``, e.g.
                (batch_size, segment_samples)
            eps: lower bound applied to the power before the square root

        Returns:
            mag, cos, sin: tensors with the shape of ``self.stft``'s outputs

        Note:
            With the default ``eps=0.0``, a bin whose real and imaginary parts
            are both zero has zero magnitude, and its cos/sin become NaN
            (0 / 0). Pass a positive ``eps`` (as ``wav_to_spectrogram_phase``
            does) to avoid this.
        """
        real, imag = self.stft(input)
        mag = self._magnitude(real, imag, eps)
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Convert waveforms to magnitude, cos, and sin of STFT.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: lower bound applied to the power before the square root

        Returns:
            mag: (batch_size, channels_num, time_steps, freq_bins)
            cos: (batch_size, channels_num, time_steps, freq_bins)
            sin: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, channels_num, segment_samples = input.shape
        # Fold channels into the batch axis so the input has the 2-D
        # (n, segment_samples) layout the stft function requires.
        x = input.reshape(batch_size * channels_num, segment_samples)
        mag, cos, sin = self.spectrogram_phase(x, eps=eps)
        # mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins)
        _, _, time_steps, freq_bins = mag.shape
        out_shape = (batch_size, channels_num, time_steps, freq_bins)
        return mag.reshape(out_shape), cos.reshape(out_shape), sin.reshape(out_shape)

    def wav_to_spectrogram(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> torch.Tensor:
        r"""Convert waveforms to a magnitude spectrogram.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: lower bound applied to the power before the square root

        Returns:
            mag: (batch_size, channels_num, time_steps, freq_bins)
        """
        mag, _, _ = self.wav_to_spectrogram_phase(input, eps)
        return mag
class Subband:
    def __init__(self, subbands_num: int):
        r"""Frequency-domain subband splitter (unused; kept for reference).

        Warning!! This class is not used!! It does not work as well as [1],
        which splits subbands in the time domain. Please refer to [1] for the
        formal implementation.

        [1] Liu, Haohe, et al. "Channel-wise subband input for better voice and
        accompaniment separation on high resolution music." arXiv preprint
        arXiv:2008.05216 (2020).

        Args:
            subbands_num: int, number of equal-width frequency subbands, e.g., 4
        """
        self.subbands_num = subbands_num

    def analysis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Split the frequency axis into subbands stacked along the channel axis.

        Output channel ``c * subbands_num + s`` holds subband ``s`` of input
        channel ``c``.

        Args:
            x: (batch_size, channels_num, time_steps, freq_bins); freq_bins is
                expected to be divisible by subbands_num.

        Returns:
            output: (batch_size, channels_num * subbands_num, time_steps,
                freq_bins // subbands_num)
        """
        n_bands = self.subbands_num
        batch, chans, frames, bins = x.shape
        band_width = bins // n_bands

        # Give the subband index its own axis, then move it in front of time.
        banded = x.reshape(batch, chans, frames, n_bands, band_width)
        banded = banded.permute(0, 1, 3, 2, 4)
        # banded: (batch, chans, n_bands, frames, band_width)
        return banded.reshape(batch, chans * n_bands, frames, band_width)

    def synthesis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Merge channel-stacked subbands back into full-band spectrograms.

        This is the inverse of ``analysis``.

        Args:
            x: (batch_size, channels_num * subbands_num, time_steps,
                freq_bins // subbands_num)

        Returns:
            output: (batch_size, channels_num, time_steps, freq_bins)
        """
        n_bands = self.subbands_num
        batch, band_chans, frames, band_width = x.shape
        chans = band_chans // n_bands
        bins = band_width * n_bands

        banded = x.reshape(batch, chans, n_bands, frames, band_width)
        banded = banded.permute(0, 1, 3, 2, 4)
        # banded: (batch, chans, frames, n_bands, band_width)
        return banded.reshape(batch, chans, frames, bins)