import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, List

def extract_chunks(A: Tensor,
                   ofx: Tensor,
                   mel_ids: Optional[Tensor] = None,
                   chunk_len: int = 128) -> Tensor:
    """Gather fixed-length time chunks from a batch of spectrograms.

    Args:
        A (Tensor): spectrograms [B, F, T]
        ofx (Tensor): start-frame offsets [num_chunks,]
        mel_ids (Tensor): batch index of each chunk [num_chunks,];
            defaults to one chunk per batch item
        chunk_len (int): number of frames per chunk
    Returns:
        chunks (Tensor): [num_chunks, F, chunk_len]
    """
    if mel_ids is None:
        mel_ids = torch.arange(A.size(0), device=A.device)
    # frame indices of each chunk within its own spectrogram
    ids = torch.arange(chunk_len, device=A.device)[None, :] + ofx[:, None]
    # shift indices into the time axis flattened across the batch ([B*T] frames)
    ids = ids + mel_ids[:, None] * A.size(2)
    chunks = A.transpose(0, 1).flatten(1)[:, ids.long()].transpose(0, 1)
    return chunks
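
# Usage sketch (shapes and values here are illustrative, not from the original
# code): pull two 128-frame chunks out of a batch of four 80-bin spectrograms.
def _demo_extract_chunks():
    A = torch.randn(4, 80, 256)       # [B, F, T]
    ofx = torch.tensor([0, 100])      # start frame of each chunk
    mel_ids = torch.tensor([1, 3])    # source batch index of each chunk
    chunks = extract_chunks(A, ofx, mel_ids, chunk_len=128)
    assert chunks.shape == (2, 80, 128)
    # each chunk is an exact slice of the corresponding spectrogram
    assert torch.equal(chunks[0], A[1, :, 0:128])
    assert torch.equal(chunks[1], A[3, :, 100:228])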

def calc_feature_match_loss(fmaps_gen: List[Tensor],
                            fmaps_org: List[Tensor]) -> Tensor:
    """Mean L1 distance between corresponding feature maps, averaged over layers."""
    loss_fmatch = 0.
    for fmap_gen, fmap_org in zip(fmaps_gen, fmaps_org):
        # detach the reference out-of-place so no gradient flows into the
        # real-data branch and the caller's tensors are left untouched
        loss_fmatch += (fmap_gen - fmap_org.detach()).abs().mean()
    loss_fmatch = loss_fmatch / len(fmaps_gen)
    return loss_fmatch
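
# Minimal sketch of how the loss is typically wired up in GAN training
# (the tensors below are stand-ins for discriminator activations):
def _demo_feature_match_loss():
    fmaps_real = [torch.randn(2, 16, 8, 8), torch.randn(2, 32, 4, 4)]
    fmaps_fake = [f + 0.1 * torch.randn_like(f) for f in fmaps_real]
    loss = calc_feature_match_loss(fmaps_fake, fmaps_real)
    assert loss.ndim == 0 and loss.item() > 0.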

class Conv2DSpectralNorm(nn.Conv2d):
    """Conv2d that rescales its weight by the estimated spectral norm
    (largest singular value of the flattened kernel) before every call."""

    def __init__(self, cnum_in: int, cnum_out: int,
                 kernel_size: int, stride: int, padding: int = 0,
                 n_iter: int = 1, eps: float = 1e-12,
                 bias: bool = True):
        super().__init__(cnum_in, cnum_out, kernel_size=kernel_size,
                         stride=stride, padding=padding, bias=bias)
        # left singular vector estimate, persisted across forward calls
        self.register_buffer("weight_u", torch.empty(self.weight.size(0), 1))
        nn.init.trunc_normal_(self.weight_u)
        self.n_iter = n_iter
        self.eps = eps

    def l2_norm(self, x: Tensor) -> Tensor:
        return F.normalize(x, p=2, dim=0, eps=self.eps)

    def forward(self, x: Tensor) -> Tensor:
        # power iteration on the kernel flattened to [cnum_out, -1]
        weight_orig = self.weight.flatten(1).detach()
        for _ in range(self.n_iter):
            v = self.l2_norm(weight_orig.t() @ self.weight_u)
            self.weight_u = self.l2_norm(weight_orig @ v)
        sigma = self.weight_u.t() @ weight_orig @ v
        # divide the stored weight in place; over repeated calls this keeps
        # its spectral norm near 1
        self.weight.data.div_(sigma)
        return super().forward(x)
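
# Sanity sketch (assumed behavior, not a test from the original code): after a
# few forward passes the repeated in-place division should drive the spectral
# norm of the flattened kernel close to 1.
def _demo_spectral_norm():
    conv = Conv2DSpectralNorm(3, 8, kernel_size=3, stride=1, padding=1)
    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        for _ in range(5):
            conv(x)
    sigma = torch.linalg.matrix_norm(conv.weight.flatten(1), ord=2)
    print(sigma.item())  # expected to be close to 1.0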

class DConv(nn.Module):
    """Spectrally normalized convolution followed by LeakyReLU."""

    def __init__(self, cnum_in: int, cnum_out: int,
                 ksize: int = 5, stride: int = 2, padding='auto'):
        super().__init__()
        padding = (ksize - 1) // 2 if padding == 'auto' else padding
        self.conv_sn = Conv2DSpectralNorm(
            cnum_in, cnum_out, ksize, stride, padding)
        # alternative: torch.nn.utils.spectral_norm(
        #     nn.Conv2d(cnum_in, cnum_out, ksize, stride, padding))
        self.leaky = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x: Tensor) -> Tensor:
        return self.leaky(self.conv_sn(x))

class PatchDiscriminator(nn.Module):
    """Stack of five stride-2 DConv blocks; returns flattened patch logits
    plus the intermediate feature maps used by the feature-matching loss."""

    def __init__(self, cnum_in: int, cnum: int):
        super().__init__()
        self.conv1 = DConv(cnum_in, cnum)
        self.conv2 = DConv(cnum, 2 * cnum)
        self.conv3 = DConv(2 * cnum, 4 * cnum)
        self.conv4 = DConv(4 * cnum, 4 * cnum)
        self.conv5 = DConv(4 * cnum, 4 * cnum)

    def forward(self, x: Tensor):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        # flatten without instantiating a throwaway nn.Flatten module
        return x5.flatten(1), [x1, x2, x3, x4]
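
# Usage sketch (the 80x128 input shape is illustrative): each DConv roughly
# halves both spatial dims, so five blocks shrink 80x128 down to 3x4
# before flattening.
def _demo_patch_discriminator():
    disc = PatchDiscriminator(cnum_in=1, cnum=32)
    spec = torch.randn(2, 1, 80, 128)
    logits, fmaps = disc(spec)
    assert logits.shape == (2, 4 * 32 * 3 * 4)  # [B, 4*cnum * 3 * 4]
    assert len(fmaps) == 4                      # inputs to the feature-match loss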