Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
from torch.nn import Conv1d | |
from torch.nn import ConvTranspose1d | |
from torch.nn.utils import weight_norm | |
from torch.nn.utils import remove_weight_norm | |
from .nsf import SourceModuleHnNSF | |
from .bigv import init_weights, AMPBlock, SnakeAlias | |
class Generator(torch.nn.Module):
    """Main BigVGAN generator with an NSF-style F0-driven source signal.

    Applies anti-aliased periodic activations (AMP blocks) after every
    transposed-convolution upsampling stage and injects a source signal
    derived from F0 at each stage.

    The hyper-parameter object ``hp`` is read for:
    ``hp.gen.mel_channels``, ``hp.gen.upsample_initial_channel``,
    ``hp.gen.upsample_rates``, ``hp.gen.upsample_kernel_sizes``,
    ``hp.gen.resblock_kernel_sizes``, ``hp.gen.resblock_dilation_sizes``
    and ``hp.audio.sampling_rate``.
    """

    # Full-scale magnitude used when converting float audio to 16-bit PCM.
    _MAX_WAV_VALUE = 32768.0

    def __init__(self, hp):
        """Build all sub-modules from the hyper-parameter object ``hp``."""
        super().__init__()
        self.hp = hp
        self.num_kernels = len(hp.gen.resblock_kernel_sizes)
        self.num_upsamples = len(hp.gen.upsample_rates)
        initial_channels = hp.gen.upsample_initial_channel

        # NOTE: sub-modules are created in the same order as before so that
        # state_dict keys, parameter order and seeded initialisation are
        # unchanged.

        # pre conv
        self.conv_pre = weight_norm(
            Conv1d(hp.gen.mel_channels, initial_channels, 7, 1, padding=3))

        # nsf: F0 is upsampled by the product of all upsample rates
        self.f0_upsamp = torch.nn.Upsample(
            scale_factor=np.prod(hp.gen.upsample_rates))
        self.m_source = SourceModuleHnNSF(sampling_rate=hp.audio.sampling_rate)
        self.noise_convs = nn.ModuleList()

        # transposed conv-based upsamplers. does not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(
                zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)):
            # Channel width halves at every stage.
            stage_channels = initial_channels // (2 ** (i + 1))

            # base
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        initial_channels // (2 ** i),
                        stage_channels,
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            # nsf: project the 1-channel source to this stage's width. For
            # every stage but the last, the conv is strided by the product of
            # the remaining upsample rates.
            if i + 1 < len(hp.gen.upsample_rates):
                stride_f0 = int(np.prod(hp.gen.upsample_rates[i + 1:]))
                self.noise_convs.append(
                    Conv1d(
                        1,
                        stage_channels,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(
                    Conv1d(1, stage_channels, kernel_size=1))

        # residual blocks using anti-aliased multi-periodicity composition
        # modules (AMP): num_kernels blocks per upsampling stage, stored flat.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            stage_channels = initial_channels // (2 ** (i + 1))
            for k, d in zip(hp.gen.resblock_kernel_sizes,
                            hp.gen.resblock_dilation_sizes):
                self.resblocks.append(AMPBlock(stage_channels, k, d))

        # post conv: width of the last stage, computed explicitly instead of
        # relying on a loop variable surviving the loop above.
        post_channels = initial_channels // (2 ** len(self.ups))
        self.activation_post = SnakeAlias(post_channels)
        self.conv_post = Conv1d(post_channels, 1, 7, 1, padding=3, bias=False)

        # weight initialization
        self.ups.apply(init_weights)

    def _harmonic_source(self, f0):
        """Turn an F0 track into the source signal fed to ``noise_convs``.

        A channel axis is inserted at dim 1, the track is upsampled by
        ``f0_upsamp``, passed through ``m_source`` with the last two axes
        swapped, and swapped back.
        """
        # assumes f0 is (batch, frames) -- TODO confirm against callers
        f0 = f0[:, None]
        f0 = self.f0_upsamp(f0).transpose(1, 2)
        return self.m_source(f0).transpose(1, 2)

    def _to_pcm16(self, audio):
        """Scale float audio to the int16 range, clamp, and cast to int16.

        ``squeeze()`` drops every size-1 axis, so a batch of one with a
        single channel becomes a 1-D waveform.
        """
        peak = self._MAX_WAV_VALUE
        audio = audio.squeeze()
        audio = peak * audio
        audio = audio.clamp(min=-peak, max=peak - 1)
        return audio.short()

    def forward(self, x, f0, train=True):
        """Generate a waveform from features ``x`` and pitch track ``f0``.

        Args:
            x: input features; the channel axis must have
                ``hp.gen.mel_channels`` entries (it feeds ``conv_pre``).
            f0: pitch track, see ``_harmonic_source``.
            train: when true, Gaussian noise scaled by 0.1 is added to ``x``
                before the pre-conv. This flag is independent of
                ``self.training``.

        Returns:
            The ``tanh``-bounded output of ``conv_post`` (one output channel).
        """
        # nsf
        har_source = self._harmonic_source(f0)

        # pre conv
        if train:
            x = x + torch.randn_like(x) * 0.1  # Perturbation
        x = self.conv_pre(x)
        x = x * torch.tanh(F.softplus(x))  # Mish-style activation

        for i in range(self.num_upsamples):
            # upsampling
            x = self.ups[i](x)
            # nsf
            x = x + self.noise_convs[i](har_source)
            # AMP blocks: average the outputs of this stage's blocks
            first = i * self.num_kernels
            xs = None
            for block_idx in range(first, first + self.num_kernels):
                if xs is None:
                    xs = self.resblocks[block_idx](x)
                else:
                    xs += self.resblocks[block_idx](x)
            x = xs / self.num_kernels

        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalisation from the upsamplers, AMP blocks and pre-conv."""
        for upsampler in self.ups:
            remove_weight_norm(upsampler)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)

    def eval(self, inference=False):
        """Switch to evaluation mode and return ``self``.

        Args:
            inference: when true, weight normalisation is also removed. Leave
                it false for validation inside a training loop so the
                weight-norm parametrisation is kept.

        Returns:
            ``self``, matching ``torch.nn.Module.eval`` so calls can be chained.
        """
        super().eval()
        # don't remove weight norm while validation in training loop
        if inference:
            self.remove_weight_norm()
        return self

    def inference(self, mel, f0):
        """Run ``forward`` without input perturbation and return int16 audio."""
        # The int16 result cannot carry gradients, so skip building the graph.
        with torch.no_grad():
            audio = self.forward(mel, f0, False)
        return self._to_pcm16(audio)

    def pitch2wav(self, f0):
        """Return the F0-derived source signal alone as int16 audio."""
        # The int16 result cannot carry gradients, so skip building the graph.
        with torch.no_grad():
            audio = self._harmonic_source(f0)
        return self._to_pcm16(audio)