Spaces:
Sleeping
Sleeping
File size: 4,384 Bytes
32b2aaa d2b7e94 32b2aaa d2b7e94 32b2aaa d2b7e94 32b2aaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
import torch
import torch.nn.functional as F
from torch import nn
from ..hparams import HParams
def _make_stft_cfg(hop_length, win_length=None):
    """Build the keyword arguments describing one STFT resolution.

    Args:
        hop_length (int): Hop size between successive frames.
        win_length (int, optional): Window length. When omitted it defaults
            to four times ``hop_length``.

    Returns:
        dict: ``n_fft``, ``hop_length`` and ``win_length``, where ``n_fft`` is
        the smallest power of two that is not below ``win_length``.
    """
    window_size = 4 * hop_length if win_length is None else win_length
    # (w - 1).bit_length() is the exponent of the next power of two >= w.
    fft_size = 1 << (window_size - 1).bit_length()
    return {
        "n_fft": fft_size,
        "hop_length": hop_length,
        "win_length": window_size,
    }
def get_stft_cfgs(hp: HParams):
    """Return the STFT configurations used by the multi-resolution loss.

    Args:
        hp (HParams): Hyper-parameters; only ``hp.wav_rate`` is read here and
            it must equal 44100.

    Returns:
        list: One config dict per hop length (100, 256 and 512), each built
        by ``_make_stft_cfg`` with its default window length.
    """
    assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
    hop_lengths = (100, 256, 512)
    return list(map(_make_stft_cfg, hop_lengths))
def stft(x, n_fft, hop_length, win_length, window):
    """Compute the magnitude spectrogram of a batch of signals.

    The transform itself always runs in float32; the magnitudes are cast back
    to the dtype of ``x`` afterwards.

    Args:
        x (Tensor): Input signal (b t).
        n_fft (int): FFT size.
        hop_length (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor passed through to ``torch.stft``.

    Returns:
        Tensor: Magnitude spectrogram (b t f).
    """
    input_dtype = x.dtype
    spectrum = torch.stft(
        x.float(),
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    magnitude = torch.abs(spectrum).to(input_dtype)
    # (b f t) -> (b t f)
    return magnitude.transpose(1, 2)
class SpectralConvergengeLoss(nn.Module):
    """Spectral convergence loss between two magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Compute the spectral convergence loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Frobenius norm of ``y_mag - x_mag`` divided by the
            Frobenius norm of ``y_mag``, taken over the whole tensor.
        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        target_norm = torch.norm(y_mag, p="fro")
        return error_norm / target_norm
class LogSTFTMagnitudeLoss(nn.Module):
    """L1 distance between ``log1p``-compressed magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Compute the log STFT magnitude loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Mean absolute difference between ``log1p(x_mag)`` and
            ``log1p(y_mag)``.
        """
        log_predicted = torch.log1p(x_mag)
        log_target = torch.log1p(y_mag)
        return F.l1_loss(log_predicted, log_target)
class STFTLoss(nn.Module):
    """STFT-based loss evaluated at a single resolution."""

    def __init__(self, hp, stft_cfg: dict, window="hann_window"):
        """Initialize the single-resolution STFT loss.

        Args:
            hp: Hyper-parameter object; stored on the module as ``self.hp``.
            stft_cfg (dict): Keyword arguments for ``stft`` (``n_fft``,
                ``hop_length`` and ``win_length``).
            window (str): Name of the ``torch`` window function to use.
        """
        super().__init__()
        self.hp = hp
        self.stft_cfg = stft_cfg
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        # Non-persistent buffer: follows the module across devices but is
        # left out of the state dict.
        make_window = getattr(torch, window)
        self.register_buffer(
            "window", make_window(stft_cfg["win_length"]), persistent=False
        )

    def forward(self, x, y):
        """Compute both spectral losses for one resolution.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            dict: ``sc`` holds the spectral convergence loss and ``mag`` the
            log STFT magnitude loss.
        """
        # (b t) -> (b t f) for the prediction first, then the target.
        predicted_mag, target_mag = (
            stft(signal, window=self.window, **self.stft_cfg) for signal in (x, y)
        )
        convergence = self.spectral_convergenge_loss(predicted_mag, target_mag)
        log_magnitude = self.log_stft_magnitude_loss(predicted_mag, target_mag)
        return {"sc": convergence, "mag": log_magnitude}
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss: averages ``STFTLoss`` over several configs."""

    def __init__(self, hp: HParams, window="hann_window"):
        """Initialize Multi resolution STFT loss module.

        Args:
            hp (HParams): Hyper-parameters; the STFT resolutions are derived
                from it through ``get_stft_cfgs``.
            window (str): Window function type.
        """
        super().__init__()
        resolutions = get_stft_cfgs(hp)
        self.stft_losses = nn.ModuleList(
            [STFTLoss(hp, cfg, window=window) for cfg in resolutions]
        )
        self.hp = hp

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (b t).
            y (Tensor): Groundtruth signal (b t).

        Returns:
            dict: Maps each loss name produced by the per-resolution losses
            (``sc`` and ``mag``) to its mean over all resolutions, cast back
            to the dtype of ``x``.
        """
        assert (
            x.dim() == 2 and y.dim() == 2
        ), f"(b t) is expected, but got {x.shape} and {y.shape}."

        output_dtype = x.dtype

        # Align length: trim both signals to the shorter of the two.
        common_length = min(x.shape[-1], y.shape[-1])
        x = x.float()[..., :common_length]
        y = y.float()[..., :common_length]

        # Group the per-resolution values by loss name.
        collected = {}
        for stft_loss in self.stft_losses:
            for name, value in stft_loss(x, y).items():
                collected.setdefault(name, []).append(value)

        return {
            name: torch.stack(values, dim=0).mean().to(output_dtype)
            for name, values in collected.items()
        }
|