File size: 1,787 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn


class MelSpectrogram(nn.Module):
    def __init__(self, 
                 sample_rate: int = 22050,
                 n_fft: int = 1024,
                 win_length: int = 1024,
                 hop_length: int = 256,
                 n_mels: int = 80,
                 f_min: float = 0,
                 f_max: float = 8000.0,
                 norm: str = 'slaney',
                 center: bool = False
                 ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.center = center

        self.pad_length = int((n_fft - hop_length)/2)

        mel_basis = torch.Tensor(librosa_mel_fn(sr=sample_rate,
                                                n_fft=n_fft, n_mels=n_mels,
                                                fmin=f_min, fmax=f_max, 
                                                norm=norm))
        window_fn = torch.hann_window(win_length)
        self.register_buffer('mel_basis', mel_basis)
        self.register_buffer('window_fn', window_fn)

    def forward(self, x):
        x_pad = torch.nn.functional.pad(
            x, (self.pad_length, self.pad_length), mode='reflect')
        spec_lin = torch.stft(x_pad, self.n_fft,
                              self.hop_length,
                              self.win_length,
                              self.window_fn,
                              center=self.center,
                              return_complex=True) # [B, F, T]
        spec_mag = spec_lin.abs().pow_(2).add_(1e-9).sqrt_()
        spec_mel = torch.matmul(self.mel_basis, spec_mag) # [B, mels, T]
        return spec_mel