File size: 4,384 Bytes
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2b7e94
 
 
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2b7e94
 
 
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2b7e94
 
 
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)


import torch
import torch.nn.functional as F
from torch import nn

from ..hparams import HParams


def _make_stft_cfg(hop_length, win_length=None):
    """Build the STFT keyword arguments for a single resolution.

    Args:
        hop_length (int): Hop size between successive frames.
        win_length (int, optional): Window length. When ``None`` it is set
            to ``4 * hop_length``.

    Returns:
        dict: Mapping with the keys ``n_fft``, ``hop_length`` and
        ``win_length``. For a positive ``win_length``, ``n_fft`` is the
        smallest power of two that is not below it.
    """
    window_size = 4 * hop_length if win_length is None else win_length
    # (w - 1).bit_length() is the exponent of the next power of two >= w.
    fft_size = 1 << (window_size - 1).bit_length()
    return {
        "n_fft": fft_size,
        "hop_length": hop_length,
        "win_length": window_size,
    }


def get_stft_cfgs(hp: HParams):
    """Return the list of STFT configurations used by the multi-resolution loss.

    Args:
        hp (HParams): Hyper-parameters; only ``hp.wav_rate`` is read here and
            it must equal 44100.

    Returns:
        list: One configuration dict per hop length (100, 256 and 512), each
        produced by ``_make_stft_cfg`` with its default window length.
    """
    assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
    hop_lengths = (100, 256, 512)
    return list(map(_make_stft_cfg, hop_lengths))


def stft(x, n_fft, hop_length, win_length, window):
    """Compute the magnitude spectrogram of a batch of signals.

    The transform itself is evaluated on a float32 copy of ``x``; the
    resulting magnitudes are cast back to the dtype ``x`` came in with.

    Args:
        x (Tensor): Input signals; the final transpose of axes 1 and 2
            requires the STFT output to be 3-D, i.e. a batched (b t) input.
        n_fft (int): FFT size.
        hop_length (int): Hop size between frames.
        win_length (int): Window length.
        window (Tensor): Window tensor handed to ``torch.stft``.

    Returns:
        Tensor: Magnitudes laid out as (b t f), in the dtype of ``x``.
    """
    spectrum = torch.stft(
        x.float(),
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    magnitude = spectrum.abs().to(x.dtype)
    # torch.stft yields (b f t); swap the last two axes to get (b t f).
    return magnitude.transpose(1, 2)


class SpectralConvergengeLoss(nn.Module):
    """Spectral convergence loss between two magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Compute the spectral convergence loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Frobenius norm of ``y_mag - x_mag`` divided by the
            Frobenius norm of ``y_mag``.
        """
        residual_norm = torch.norm(y_mag - x_mag, p="fro")
        target_norm = torch.norm(y_mag, p="fro")
        # NOTE: no epsilon is applied, so an all-zero target divides by zero.
        return residual_norm / target_norm


class LogSTFTMagnitudeLoss(nn.Module):
    """L1 distance between log-compressed magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Compute the log STFT magnitude loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: L1 loss between ``log1p(x_mag)`` and ``log1p(y_mag)``.
        """
        # log1p keeps the compression finite for zero-valued bins.
        log_predicted = torch.log1p(x_mag)
        log_target = torch.log1p(y_mag)
        return F.l1_loss(log_predicted, log_target)


class STFTLoss(nn.Module):
    """STFT loss at a single resolution."""

    def __init__(self, hp, stft_cfg: dict, window="hann_window"):
        """Initialize the single-resolution STFT loss.

        Args:
            hp: Hyper-parameters; stored as ``self.hp`` and not otherwise
                read by this class.
            stft_cfg (dict): Keyword arguments forwarded to ``stft``; the
                ``win_length`` entry also sets the size of the window.
            window (str): Name of a window function looked up on ``torch``.
        """
        super().__init__()
        self.hp = hp
        self.stft_cfg = stft_cfg
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        # Registered as a non-persistent buffer: it follows the module across
        # devices but is left out of the state dict.
        make_window = getattr(torch, window)
        self.register_buffer(
            "window", make_window(stft_cfg["win_length"]), persistent=False
        )

    def forward(self, x, y):
        """Compute both loss terms for one resolution.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            dict: ``sc`` holds the spectral convergence loss and ``mag`` the
            log STFT magnitude loss.
        """
        # (b t) -> (b t f) for the prediction first, then the target.
        x_mag, y_mag = (
            stft(signal, window=self.window, **self.stft_cfg) for signal in (x, y)
        )
        return {
            "sc": self.spectral_convergenge_loss(x_mag, y_mag),
            "mag": self.log_stft_magnitude_loss(x_mag, y_mag),
        }


class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss: the mean of several ``STFTLoss`` modules."""

    def __init__(self, hp: HParams, window="hann_window"):
        """Initialize Multi resolution STFT loss module.

        Args:
            hp (HParams): Hyper-parameters; passed to ``get_stft_cfgs`` to
                obtain the resolutions and forwarded to every ``STFTLoss``.
            window (str): Window function type.
        """
        super().__init__()
        self.hp = hp
        self.stft_losses = nn.ModuleList(
            [STFTLoss(hp, cfg, window=window) for cfg in get_stft_cfgs(hp)]
        )

    def forward(self, x, y):
        """Average every loss term over all resolutions.

        Args:
            x (Tensor): Predicted signal (b t).
            y (Tensor): Groundtruth signal (b t).

        Returns:
            dict: One scalar tensor per key returned by the per-resolution
            losses, averaged across resolutions and cast to the dtype ``x``
            had on entry.
        """
        assert (
            x.dim() == 2 and y.dim() == 2
        ), f"(b t) is expected, but got {x.shape} and {y.shape}."

        out_dtype = x.dtype

        # Trim both signals to their common length; compute in float32.
        common_length = min(x.shape[-1], y.shape[-1])
        x = x.float()[..., :common_length]
        y = y.float()[..., :common_length]

        # Gather each named term across resolutions, keeping insertion order.
        collected = {}
        for per_resolution in [loss_fn(x, y) for loss_fn in self.stft_losses]:
            for name, value in per_resolution.items():
                collected.setdefault(name, []).append(value)

        return {
            name: torch.stack(values, dim=0).mean().to(out_dtype)
            for name, values in collected.items()
        }