wasmdashai's picture
first commit
7694c84
raw
history blame
3.47 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, List
def extract_chunks(A: Tensor,
                   ofx: Tensor,
                   mel_ids: Optional[Tensor] = None,
                   chunk_len: int = 128):
    """Cut fixed-length chunks along the time axis out of a batch of spectrograms.

    Args:
        A (Tensor): spectrograms [B, F, T]
        ofx (Tensor): start offset of each chunk along T [num_chunks,]
        mel_ids (Tensor, optional): batch index of the spectrogram each chunk
            is taken from [num_chunks,]. When omitted, chunk ``i`` is taken
            from spectrogram ``i``, so ``ofx`` must hold one offset per batch
            entry.
        chunk_len (int): number of frames in every chunk.

    Returns:
        chunks (Tensor): [num_chunks, F, chunk_len]

    Note:
        Offsets are not bounds-checked. Because the batch and time axes are
        flattened together before indexing, a chunk with
        ``ofx + chunk_len > T`` reads frames of the following spectrogram
        (or raises ``IndexError`` past the end of the last one).
    """
    # Resolve the default before anything touches ``mel_ids``; the previous
    # code called ``len(mel_ids)`` first and crashed on ``None``.
    if mel_ids is None:
        mel_ids = torch.arange(A.size(0), device=A.device)

    num_frames = A.size(2)

    # Frame index of every chunk element inside its own spectrogram:
    # [1, chunk_len] + [num_chunks, 1] -> [num_chunks, chunk_len]
    frame_ids = torch.arange(chunk_len, device=A.device)[None, :] + ofx[:, None]

    # Shift by the start of the owning spectrogram in the flattened B*T axis.
    flat_ids = frame_ids + mel_ids[:, None] * num_frames

    # [B, F, T] -> [F, B*T], gather -> [F, num_chunks, chunk_len],
    # then put the chunk axis first.
    flat_spec = A.transpose(0, 1).flatten(1)
    chunks = flat_spec[:, flat_ids.long()].transpose(0, 1)
    return chunks
def calc_feature_match_loss(fmaps_gen: List[Tensor],
                            fmaps_org: List[Tensor]
                            ):
    """Feature-matching loss: mean absolute difference between feature maps.

    The maps are paired up in order; for every pair the mean of
    ``|fmap_gen - fmap_org|`` is taken, and the per-pair means are summed and
    divided by ``len(fmaps_gen)``.

    Args:
        fmaps_gen (List[Tensor]): feature maps computed from generated data.
        fmaps_org (List[Tensor]): feature maps computed from reference data.
            They are treated as constants: no gradient flows into them, and
            the tensors passed in are left untouched.

    Returns:
        Scalar tensor holding the averaged loss.

    Raises:
        ZeroDivisionError: if ``fmaps_gen`` is empty.
    """
    loss_fmatch = 0.
    for fmap_gen, fmap_org in zip(fmaps_gen, fmaps_org):
        # ``detach()`` returns a new tensor cut off from the graph. The
        # in-place ``detach_()`` used previously modified the caller's tensor
        # and raises a RuntimeError when that tensor is a view.
        target = fmap_org.detach()
        loss_fmatch += (fmap_gen - target).abs().mean()

    loss_fmatch = loss_fmatch / len(fmaps_gen)
    return loss_fmatch
class Conv2DSpectralNorm(nn.Conv2d):
    """Convolution layer that applies Spectral Normalization before every call.

    The largest singular value of the weight, reshaped to
    ``[cnum_out, cnum_in * kH * kW]``, is estimated with power iteration and
    the stored weight is divided by that estimate in place.

    NOTE: the rescaling is written back into ``self.weight`` and the
    ``weight_u`` buffer is updated on every ``forward`` call; there is no
    check of ``self.training``, so this also happens in eval mode.

    Args:
        cnum_in (int): number of input channels.
        cnum_out (int): number of output channels.
        kernel_size (int): convolution kernel size.
        stride (int): convolution stride.
        padding (int): zero padding added to the input.
        n_iter (int): power-iteration steps per forward call; must be >= 1.
        eps (float): epsilon used when L2-normalizing the iteration vectors.
        bias (bool): whether the convolution has a bias term.

    Raises:
        ValueError: if ``n_iter`` is smaller than 1.
    """

    def __init__(self, cnum_in: int, cnum_out: int,
                 kernel_size: int, stride: int, padding: int = 0,
                 n_iter: int = 1, eps: float = 1e-12,
                 bias: bool = True):
        # With zero iterations ``forward`` would never assign ``v`` and fail
        # with an UnboundLocalError, so reject the value up front.
        if n_iter < 1:
            raise ValueError(
                f"n_iter must be a positive integer, got {n_iter}")

        super().__init__(cnum_in,
                         cnum_out, kernel_size=kernel_size,
                         stride=stride, padding=padding, bias=bias)

        # Power-iteration vector of shape [cnum_out, 1]; kept as a buffer so
        # it follows the module across devices and into the state dict.
        self.register_buffer("weight_u", torch.empty(self.weight.size(0), 1))
        nn.init.trunc_normal_(self.weight_u)
        self.n_iter = n_iter
        self.eps = eps

    def l2_norm(self, x):
        """Return ``x`` scaled to unit L2 norm along dim 0."""
        return F.normalize(x, p=2, dim=0, eps=self.eps)

    def forward(self, x):
        """Rescale the weight by its estimated spectral norm, then convolve."""
        # Detached 2-D view of the kernel: the estimate itself is not part of
        # the autograd graph.
        weight_orig = self.weight.flatten(1).detach()

        # Power iteration: alternate between the right (v) and left (u)
        # vectors.
        for _ in range(self.n_iter):
            v = self.l2_norm(weight_orig.t() @ self.weight_u)
            self.weight_u = self.l2_norm(weight_orig @ v)

        # sigma = u^T W v, the current estimate of the largest singular value.
        sigma = self.weight_u.t() @ weight_orig @ v
        # In-place update through ``.data`` so autograd does not record it.
        self.weight.data.div_(sigma)

        x = super().forward(x)
        return x
class DConv(nn.Module):
    """Spectrally normalized convolution followed by a LeakyReLU (slope 0.2).

    With ``padding='auto'`` the padding is set to ``(ksize - 1) // 2``;
    any other value is forwarded to the convolution unchanged.
    """

    def __init__(self, cnum_in,
                 cnum_out, ksize=5, stride=2, padding='auto'):
        super().__init__()

        if padding == 'auto':
            padding = (ksize - 1) // 2

        # Alternative kept for reference:
        # spectral_norm(nn.Conv2d(cnum_in, cnum_out, ksize, stride, padding))
        self.conv_sn = Conv2DSpectralNorm(
            cnum_in, cnum_out,
            kernel_size=ksize, stride=stride, padding=padding)
        self.leaky = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        """Apply the normalized convolution, then the activation."""
        return self.leaky(self.conv_sn(x))
class PatchDiscriminator(nn.Module):
    """Stack of five ``DConv`` blocks used as a discriminator.

    Channel widths run ``cnum_in -> cnum -> 2*cnum -> 4*cnum -> 4*cnum ->
    4*cnum``. ``forward`` returns the output of the last block flattened
    from dim 1 onwards, together with the outputs of the first four blocks.
    """

    def __init__(self, cnum_in, cnum):
        super().__init__()
        narrow, medium, wide = cnum, 2 * cnum, 4 * cnum

        self.conv1 = DConv(cnum_in, narrow)
        self.conv2 = DConv(narrow, medium)
        self.conv3 = DConv(medium, wide)
        self.conv4 = DConv(wide, wide)
        self.conv5 = DConv(wide, wide)

    def forward(self, x):
        """Return ``(flattened_output, [x1, x2, x3, x4])``."""
        stages = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)

        activations = []
        hidden = x
        for stage in stages:
            hidden = stage(hidden)
            activations.append(hidden)

        # The final activation is the discriminator output; only the earlier
        # ones are returned as intermediate feature maps.
        return hidden.flatten(1), activations[:-1]