Commit
·
fd06d88
1
Parent(s):
847e2fc
upload infer utils
Browse files- wav2filterbank.py +313 -0
- xvector_sincnet.py +223 -0
wav2filterbank.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import logging
|
| 6 |
+
import math
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
CONSTANT = 1e-5
|
| 10 |
+
|
| 11 |
+
def normalize_batch(x, seq_len, normalize_type):
|
| 12 |
+
x_mean = None
|
| 13 |
+
x_std = None
|
| 14 |
+
if normalize_type == "per_feature":
|
| 15 |
+
x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
| 16 |
+
x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
| 17 |
+
for i in range(x.shape[0]):
|
| 18 |
+
if x[i, :, : seq_len[i]].shape[1] == 1:
|
| 19 |
+
raise ValueError(
|
| 20 |
+
"normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
|
| 21 |
+
"in torch.std() returning nan. Make sure your audio length has enough samples for a single "
|
| 22 |
+
"feature (ex. at least `hop_length` for Mel Spectrograms)."
|
| 23 |
+
)
|
| 24 |
+
x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
|
| 25 |
+
x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
|
| 26 |
+
# make sure x_std is not zero
|
| 27 |
+
x_std += CONSTANT
|
| 28 |
+
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
|
| 29 |
+
elif normalize_type == "all_features":
|
| 30 |
+
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
|
| 31 |
+
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
|
| 32 |
+
for i in range(x.shape[0]):
|
| 33 |
+
x_mean[i] = x[i, :, : seq_len[i].item()].mean()
|
| 34 |
+
x_std[i] = x[i, :, : seq_len[i].item()].std()
|
| 35 |
+
# make sure x_std is not zero
|
| 36 |
+
x_std += CONSTANT
|
| 37 |
+
return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
|
| 38 |
+
elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
|
| 39 |
+
x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
|
| 40 |
+
x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
|
| 41 |
+
return (
|
| 42 |
+
(x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
|
| 43 |
+
x_mean,
|
| 44 |
+
x_std,
|
| 45 |
+
)
|
| 46 |
+
else:
|
| 47 |
+
return x, x_mean, x_std
|
| 48 |
+
|
| 49 |
+
def splice_frames(x, frame_splicing):
|
| 50 |
+
""" Stacks frames together across feature dim
|
| 51 |
+
|
| 52 |
+
input is batch_size, feature_dim, num_frames
|
| 53 |
+
output is batch_size, feature_dim*frame_splicing, num_frames
|
| 54 |
+
|
| 55 |
+
"""
|
| 56 |
+
seq = [x]
|
| 57 |
+
for n in range(1, frame_splicing):
|
| 58 |
+
seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
|
| 59 |
+
return torch.cat(seq, dim=1)
|
| 60 |
+
|
| 61 |
+
class FilterbankFeatures(nn.Module):
|
| 62 |
+
"""Featurizer that converts wavs to Mel Spectrograms.
|
| 63 |
+
See AudioToMelSpectrogramPreprocessor for args.
|
| 64 |
+
|
| 65 |
+
"normalize": "per_feature",
|
| 66 |
+
"window_size": 0.025,
|
| 67 |
+
"sample_rate": 16000,
|
| 68 |
+
"window_stride": 0.01,
|
| 69 |
+
"window": "hann",
|
| 70 |
+
"features": 80,
|
| 71 |
+
"n_fft": 512,
|
| 72 |
+
"frame_splicing": 1,
|
| 73 |
+
"dither": 1e-05
|
| 74 |
+
|
| 75 |
+
n_window_size=window_size * sample_rate,
|
| 76 |
+
n_window_stride = window_stride * sample_rate,
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
sample_rate=16000,
|
| 82 |
+
n_window_size=400,
|
| 83 |
+
n_window_stride=160,
|
| 84 |
+
window="hann",
|
| 85 |
+
normalize="per_feature",
|
| 86 |
+
n_fft=512,
|
| 87 |
+
preemph=0.97,
|
| 88 |
+
nfilt=80,
|
| 89 |
+
lowfreq=0,
|
| 90 |
+
highfreq=None,
|
| 91 |
+
log=True,
|
| 92 |
+
log_zero_guard_type="add",
|
| 93 |
+
log_zero_guard_value=2 ** -24,
|
| 94 |
+
dither=CONSTANT,
|
| 95 |
+
pad_to=16,
|
| 96 |
+
max_duration=16.7,
|
| 97 |
+
frame_splicing=1,
|
| 98 |
+
exact_pad=False,
|
| 99 |
+
pad_value=0,
|
| 100 |
+
mag_power=2.0,
|
| 101 |
+
use_grads=False,
|
| 102 |
+
rng=None,
|
| 103 |
+
nb_augmentation_prob=0.0,
|
| 104 |
+
nb_max_freq=4000,
|
| 105 |
+
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
|
| 106 |
+
stft_conv=False, # Deprecated arguments; kept for config compatibility
|
| 107 |
+
):
|
| 108 |
+
super().__init__()
|
| 109 |
+
if stft_conv or stft_exact_pad:
|
| 110 |
+
logging.warning(
|
| 111 |
+
"Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
|
| 112 |
+
"for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
|
| 113 |
+
"as needed."
|
| 114 |
+
)
|
| 115 |
+
if exact_pad and n_window_stride % 2 == 1:
|
| 116 |
+
raise NotImplementedError(
|
| 117 |
+
f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
|
| 118 |
+
"returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
|
| 119 |
+
)
|
| 120 |
+
self.log_zero_guard_value = log_zero_guard_value
|
| 121 |
+
if (
|
| 122 |
+
n_window_size is None
|
| 123 |
+
or n_window_stride is None
|
| 124 |
+
or not isinstance(n_window_size, int)
|
| 125 |
+
or not isinstance(n_window_stride, int)
|
| 126 |
+
or n_window_size <= 0
|
| 127 |
+
or n_window_stride <= 0
|
| 128 |
+
):
|
| 129 |
+
raise ValueError(
|
| 130 |
+
f"{self} got an invalid value for either n_window_size or "
|
| 131 |
+
f"n_window_stride. Both must be positive ints."
|
| 132 |
+
)
|
| 133 |
+
logging.info(f"PADDING: {pad_to}")
|
| 134 |
+
|
| 135 |
+
self.win_length = n_window_size
|
| 136 |
+
self.hop_length = n_window_stride
|
| 137 |
+
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
|
| 138 |
+
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
|
| 139 |
+
|
| 140 |
+
if exact_pad:
|
| 141 |
+
logging.info("STFT using exact pad")
|
| 142 |
+
torch_windows = {
|
| 143 |
+
'hann': torch.hann_window,
|
| 144 |
+
'hamming': torch.hamming_window,
|
| 145 |
+
'blackman': torch.blackman_window,
|
| 146 |
+
'bartlett': torch.bartlett_window,
|
| 147 |
+
'none': None,
|
| 148 |
+
}
|
| 149 |
+
window_fn = torch_windows.get(window, None)
|
| 150 |
+
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
|
| 151 |
+
self.register_buffer("window", window_tensor)
|
| 152 |
+
self.stft = lambda x: torch.stft(
|
| 153 |
+
x,
|
| 154 |
+
n_fft=self.n_fft,
|
| 155 |
+
hop_length=self.hop_length,
|
| 156 |
+
win_length=self.win_length,
|
| 157 |
+
center=False if exact_pad else True,
|
| 158 |
+
window=self.window.to(dtype=torch.float),
|
| 159 |
+
return_complex=True,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
self.normalize = normalize
|
| 163 |
+
self.log = log
|
| 164 |
+
self.dither = dither
|
| 165 |
+
self.frame_splicing = frame_splicing
|
| 166 |
+
self.nfilt = nfilt
|
| 167 |
+
self.preemph = preemph
|
| 168 |
+
self.pad_to = pad_to
|
| 169 |
+
highfreq = highfreq or sample_rate / 2
|
| 170 |
+
|
| 171 |
+
filterbanks = torch.tensor(
|
| 172 |
+
librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq),
|
| 173 |
+
dtype=torch.float,
|
| 174 |
+
).unsqueeze(0)
|
| 175 |
+
self.register_buffer("fb", filterbanks)
|
| 176 |
+
|
| 177 |
+
# Calculate maximum sequence length
|
| 178 |
+
max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
|
| 179 |
+
max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
|
| 180 |
+
self.max_length = max_length + max_pad
|
| 181 |
+
self.pad_value = pad_value
|
| 182 |
+
self.mag_power = mag_power
|
| 183 |
+
|
| 184 |
+
# We want to avoid taking the log of zero
|
| 185 |
+
# There are two options: either adding or clamping to a small value
|
| 186 |
+
if log_zero_guard_type not in ["add", "clamp"]:
|
| 187 |
+
raise ValueError(
|
| 188 |
+
f"{self} received {log_zero_guard_type} for the "
|
| 189 |
+
f"log_zero_guard_type parameter. It must be either 'add' or "
|
| 190 |
+
f"'clamp'."
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
self.use_grads = use_grads
|
| 194 |
+
if not use_grads:
|
| 195 |
+
self.forward = torch.no_grad()(self.forward)
|
| 196 |
+
self._rng = random.Random() if rng is None else rng
|
| 197 |
+
self.nb_augmentation_prob = nb_augmentation_prob
|
| 198 |
+
if self.nb_augmentation_prob > 0.0:
|
| 199 |
+
if nb_max_freq >= sample_rate / 2:
|
| 200 |
+
self.nb_augmentation_prob = 0.0
|
| 201 |
+
else:
|
| 202 |
+
self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
|
| 203 |
+
|
| 204 |
+
# log_zero_guard_value is the the small we want to use, we support
|
| 205 |
+
# an actual number, or "tiny", or "eps"
|
| 206 |
+
self.log_zero_guard_type = log_zero_guard_type
|
| 207 |
+
logging.debug(f"sr: {sample_rate}")
|
| 208 |
+
logging.debug(f"n_fft: {self.n_fft}")
|
| 209 |
+
logging.debug(f"win_length: {self.win_length}")
|
| 210 |
+
logging.debug(f"hop_length: {self.hop_length}")
|
| 211 |
+
logging.debug(f"n_mels: {nfilt}")
|
| 212 |
+
logging.debug(f"fmin: {lowfreq}")
|
| 213 |
+
logging.debug(f"fmax: {highfreq}")
|
| 214 |
+
logging.debug(f"using grads: {use_grads}")
|
| 215 |
+
logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")
|
| 216 |
+
|
| 217 |
+
def log_zero_guard_value_fn(self, x):
|
| 218 |
+
if isinstance(self.log_zero_guard_value, str):
|
| 219 |
+
if self.log_zero_guard_value == "tiny":
|
| 220 |
+
return torch.finfo(x.dtype).tiny
|
| 221 |
+
elif self.log_zero_guard_value == "eps":
|
| 222 |
+
return torch.finfo(x.dtype).eps
|
| 223 |
+
else:
|
| 224 |
+
raise ValueError(
|
| 225 |
+
f"{self} received {self.log_zero_guard_value} for the "
|
| 226 |
+
f"log_zero_guard_type parameter. It must be either a "
|
| 227 |
+
f"number, 'tiny', or 'eps'"
|
| 228 |
+
)
|
| 229 |
+
else:
|
| 230 |
+
return self.log_zero_guard_value
|
| 231 |
+
|
| 232 |
+
def get_seq_len(self, seq_len):
|
| 233 |
+
# Assuming that center is True is stft_pad_amount = 0
|
| 234 |
+
pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
|
| 235 |
+
seq_len = torch.floor((seq_len + pad_amount - self.n_fft) / self.hop_length) + 1
|
| 236 |
+
return seq_len.to(dtype=torch.long)
|
| 237 |
+
|
| 238 |
+
@property
|
| 239 |
+
def filter_banks(self):
|
| 240 |
+
return self.fb
|
| 241 |
+
|
| 242 |
+
def forward(self, x, seq_len, linear_spec=False):
|
| 243 |
+
seq_len = self.get_seq_len(seq_len.float())
|
| 244 |
+
|
| 245 |
+
if self.stft_pad_amount is not None:
|
| 246 |
+
x = torch.nn.functional.pad(
|
| 247 |
+
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
|
| 248 |
+
).squeeze(1)
|
| 249 |
+
|
| 250 |
+
# dither (only in training mode for eval determinism)
|
| 251 |
+
if self.training and self.dither > 0:
|
| 252 |
+
x += self.dither * torch.randn_like(x)
|
| 253 |
+
|
| 254 |
+
# do preemphasis
|
| 255 |
+
if self.preemph is not None:
|
| 256 |
+
x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
|
| 257 |
+
|
| 258 |
+
# disable autocast to get full range of stft values
|
| 259 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 260 |
+
x = self.stft(x)
|
| 261 |
+
|
| 262 |
+
# torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
|
| 263 |
+
# guard is needed for sqrt if grads are passed through
|
| 264 |
+
guard = 0 if not self.use_grads else CONSTANT
|
| 265 |
+
x = torch.view_as_real(x)
|
| 266 |
+
x = torch.sqrt(x.pow(2).sum(-1) + guard)
|
| 267 |
+
|
| 268 |
+
if self.training and self.nb_augmentation_prob > 0.0:
|
| 269 |
+
for idx in range(x.shape[0]):
|
| 270 |
+
if self._rng.random() < self.nb_augmentation_prob:
|
| 271 |
+
x[idx, self._nb_max_fft_bin :, :] = 0.0
|
| 272 |
+
|
| 273 |
+
# get power spectrum
|
| 274 |
+
if self.mag_power != 1.0:
|
| 275 |
+
x = x.pow(self.mag_power)
|
| 276 |
+
|
| 277 |
+
# return plain spectrogram if required
|
| 278 |
+
if linear_spec:
|
| 279 |
+
return x, seq_len
|
| 280 |
+
|
| 281 |
+
# dot with filterbank energies
|
| 282 |
+
x = torch.matmul(self.fb.to(x.dtype), x)
|
| 283 |
+
# log features if required
|
| 284 |
+
if self.log:
|
| 285 |
+
if self.log_zero_guard_type == "add":
|
| 286 |
+
x = torch.log(x + self.log_zero_guard_value_fn(x))
|
| 287 |
+
elif self.log_zero_guard_type == "clamp":
|
| 288 |
+
x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
|
| 289 |
+
else:
|
| 290 |
+
raise ValueError("log_zero_guard_type was not understood")
|
| 291 |
+
|
| 292 |
+
# frame splicing if required
|
| 293 |
+
if self.frame_splicing > 1:
|
| 294 |
+
x = splice_frames(x, self.frame_splicing)
|
| 295 |
+
|
| 296 |
+
# normalize if required
|
| 297 |
+
if self.normalize:
|
| 298 |
+
x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)
|
| 299 |
+
|
| 300 |
+
# mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
|
| 301 |
+
max_len = x.size(-1)
|
| 302 |
+
mask = torch.arange(max_len).to(x.device)
|
| 303 |
+
mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
|
| 304 |
+
x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
|
| 305 |
+
del mask
|
| 306 |
+
pad_to = self.pad_to
|
| 307 |
+
if pad_to == "max":
|
| 308 |
+
x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
|
| 309 |
+
elif pad_to > 0:
|
| 310 |
+
pad_amt = x.size(-1) % pad_to
|
| 311 |
+
if pad_amt != 0:
|
| 312 |
+
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
|
| 313 |
+
return x, seq_len
|
xvector_sincnet.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import warnings
|
| 7 |
+
from asteroid_filterbanks import Encoder, ParamSincFB
|
| 8 |
+
|
| 9 |
+
def merge_dict(defaults: dict, custom: dict = None):
|
| 10 |
+
params = dict(defaults)
|
| 11 |
+
if custom is not None:
|
| 12 |
+
params.update(custom)
|
| 13 |
+
return params
|
| 14 |
+
|
| 15 |
+
class StatsPool(nn.Module):
|
| 16 |
+
"""Statistics pooling
|
| 17 |
+
Compute temporal mean and (unbiased) standard deviation
|
| 18 |
+
and returns their concatenation.
|
| 19 |
+
Reference
|
| 20 |
+
---------
|
| 21 |
+
https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def forward(
|
| 25 |
+
self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
|
| 26 |
+
) -> torch.Tensor:
|
| 27 |
+
"""Forward pass
|
| 28 |
+
Parameters
|
| 29 |
+
----------
|
| 30 |
+
sequences : (batch, channel, frames) torch.Tensor
|
| 31 |
+
Sequences.
|
| 32 |
+
weights : (batch, frames) torch.Tensor, optional
|
| 33 |
+
When provided, compute weighted mean and standard deviation.
|
| 34 |
+
Returns
|
| 35 |
+
-------
|
| 36 |
+
output : (batch, 2 * channel) torch.Tensor
|
| 37 |
+
Concatenation of mean and (unbiased) standard deviation.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
if weights is None:
|
| 41 |
+
mean = sequences.mean(dim=2)
|
| 42 |
+
std = sequences.std(dim=2, unbiased=True)
|
| 43 |
+
|
| 44 |
+
else:
|
| 45 |
+
weights = weights.unsqueeze(dim=1)
|
| 46 |
+
# (batch, 1, frames)
|
| 47 |
+
|
| 48 |
+
num_frames = sequences.shape[2]
|
| 49 |
+
num_weights = weights.shape[2]
|
| 50 |
+
if num_frames != num_weights:
|
| 51 |
+
warnings.warn(
|
| 52 |
+
f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers."
|
| 53 |
+
)
|
| 54 |
+
weights = F.interpolate(
|
| 55 |
+
weights, size=num_frames, mode="linear", align_corners=False
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
v1 = weights.sum(dim=2)
|
| 59 |
+
mean = torch.sum(sequences * weights, dim=2) / v1
|
| 60 |
+
|
| 61 |
+
dx2 = torch.square(sequences - mean.unsqueeze(2))
|
| 62 |
+
v2 = torch.square(weights).sum(dim=2)
|
| 63 |
+
|
| 64 |
+
var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1)
|
| 65 |
+
std = torch.sqrt(var)
|
| 66 |
+
|
| 67 |
+
return torch.cat([mean, std], dim=1)
|
| 68 |
+
|
| 69 |
+
class SincNet(nn.Module):
|
| 70 |
+
def __init__(self, sample_rate: int = 16000, stride: int = 1):
|
| 71 |
+
super().__init__()
|
| 72 |
+
|
| 73 |
+
if sample_rate != 16000:
|
| 74 |
+
raise NotImplementedError("PyanNet only supports 16kHz audio for now.")
|
| 75 |
+
# TODO: add support for other sample rate. it should be enough to multiply
|
| 76 |
+
# kernel_size by (sample_rate / 16000). but this needs to be double-checked.
|
| 77 |
+
|
| 78 |
+
self.stride = stride
|
| 79 |
+
|
| 80 |
+
self.wav_norm1d = nn.InstanceNorm1d(1, affine=True)
|
| 81 |
+
|
| 82 |
+
self.conv1d = nn.ModuleList()
|
| 83 |
+
self.pool1d = nn.ModuleList()
|
| 84 |
+
self.norm1d = nn.ModuleList()
|
| 85 |
+
|
| 86 |
+
self.conv1d.append(
|
| 87 |
+
Encoder(
|
| 88 |
+
ParamSincFB(
|
| 89 |
+
80,
|
| 90 |
+
251,
|
| 91 |
+
stride=self.stride,
|
| 92 |
+
sample_rate=sample_rate,
|
| 93 |
+
min_low_hz=50,
|
| 94 |
+
min_band_hz=50,
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
)
|
| 98 |
+
self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
|
| 99 |
+
self.norm1d.append(nn.InstanceNorm1d(80, affine=True))
|
| 100 |
+
|
| 101 |
+
self.conv1d.append(nn.Conv1d(80, 60, 5, stride=1))
|
| 102 |
+
self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
|
| 103 |
+
self.norm1d.append(nn.InstanceNorm1d(60, affine=True))
|
| 104 |
+
|
| 105 |
+
self.conv1d.append(nn.Conv1d(60, 60, 5, stride=1))
|
| 106 |
+
self.pool1d.append(nn.MaxPool1d(3, stride=3, padding=0, dilation=1))
|
| 107 |
+
self.norm1d.append(nn.InstanceNorm1d(60, affine=True))
|
| 108 |
+
|
| 109 |
+
def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
|
| 110 |
+
"""Pass forward
|
| 111 |
+
Parameters
|
| 112 |
+
----------
|
| 113 |
+
waveforms : (batch, channel, sample)
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
outputs = self.wav_norm1d(waveforms)
|
| 117 |
+
|
| 118 |
+
for c, (conv1d, pool1d, norm1d) in enumerate(
|
| 119 |
+
zip(self.conv1d, self.pool1d, self.norm1d)
|
| 120 |
+
):
|
| 121 |
+
|
| 122 |
+
outputs = conv1d(outputs)
|
| 123 |
+
|
| 124 |
+
# https://github.com/mravanelli/SincNet/issues/4
|
| 125 |
+
if c == 0:
|
| 126 |
+
outputs = torch.abs(outputs)
|
| 127 |
+
|
| 128 |
+
outputs = F.leaky_relu(norm1d(pool1d(outputs)))
|
| 129 |
+
|
| 130 |
+
return outputs
|
| 131 |
+
|
| 132 |
+
class XVectorSincNet(nn.Module):
|
| 133 |
+
|
| 134 |
+
SINCNET_DEFAULTS = {"stride": 10}
|
| 135 |
+
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
sample_rate: int = 16000,
|
| 139 |
+
# num_channels: int = 1,
|
| 140 |
+
sincnet: dict = dict(
|
| 141 |
+
stride=10,
|
| 142 |
+
sample_rate=16000
|
| 143 |
+
),
|
| 144 |
+
dimension: int = 512,
|
| 145 |
+
# task: Optional[Task] = None,
|
| 146 |
+
):
|
| 147 |
+
super(XVectorSincNet, self).__init__()
|
| 148 |
+
|
| 149 |
+
sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet)
|
| 150 |
+
sincnet["sample_rate"] = sample_rate
|
| 151 |
+
|
| 152 |
+
# self.save_hyperparameters("sincnet", "dimension")
|
| 153 |
+
|
| 154 |
+
self.sincnet = SincNet(**sincnet)
|
| 155 |
+
in_channel = 60
|
| 156 |
+
|
| 157 |
+
self.tdnns = nn.ModuleList()
|
| 158 |
+
out_channels = [512, 512, 512, 512, 1500]
|
| 159 |
+
kernel_sizes = [5, 3, 3, 1, 1]
|
| 160 |
+
dilations = [1, 2, 3, 1, 1]
|
| 161 |
+
|
| 162 |
+
for out_channel, kernel_size, dilation in zip(
|
| 163 |
+
out_channels, kernel_sizes, dilations
|
| 164 |
+
):
|
| 165 |
+
self.tdnns.extend(
|
| 166 |
+
[
|
| 167 |
+
nn.Conv1d(
|
| 168 |
+
in_channels=in_channel,
|
| 169 |
+
out_channels=out_channel,
|
| 170 |
+
kernel_size=kernel_size,
|
| 171 |
+
dilation=dilation,
|
| 172 |
+
),
|
| 173 |
+
nn.LeakyReLU(),
|
| 174 |
+
nn.BatchNorm1d(out_channel),
|
| 175 |
+
]
|
| 176 |
+
)
|
| 177 |
+
in_channel = out_channel
|
| 178 |
+
|
| 179 |
+
self.stats_pool = StatsPool()
|
| 180 |
+
|
| 181 |
+
self.embedding = nn.Linear(in_channel * 2, dimension)
|
| 182 |
+
|
| 183 |
+
def forward(
|
| 184 |
+
self, waveforms: torch.Tensor, weights: torch.Tensor = None
|
| 185 |
+
) -> torch.Tensor:
|
| 186 |
+
"""
|
| 187 |
+
Parameters
|
| 188 |
+
----------
|
| 189 |
+
waveforms : torch.Tensor
|
| 190 |
+
Batch of waveforms with shape (batch, channel, sample)
|
| 191 |
+
weights : torch.Tensor, optional
|
| 192 |
+
Batch of weights with shape (batch, frame).
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
outputs = self.sincnet(waveforms).squeeze(dim=1)
|
| 196 |
+
for tdnn in self.tdnns:
|
| 197 |
+
outputs = tdnn(outputs)
|
| 198 |
+
outputs = self.stats_pool(outputs, weights=weights)
|
| 199 |
+
return self.embedding(outputs)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
""" Load model
|
| 203 |
+
|
| 204 |
+
def cal_xvector_sincnet_embedding(xvector_model, ref_wav, max_length=5, sr=16000):
|
| 205 |
+
wavs = []
|
| 206 |
+
for i in range(0, len(ref_wav), max_length*sr):
|
| 207 |
+
wav = ref_wav[i:i + max_length*sr]
|
| 208 |
+
wav = np.concatenate([wav, np.zeros(max(0, max_length * sr - len(wav)))])
|
| 209 |
+
wavs.append(wav)
|
| 210 |
+
wavs = torch.from_numpy(np.stack(wavs))
|
| 211 |
+
if use_gpu:
|
| 212 |
+
wavs = wavs.cuda()
|
| 213 |
+
embed = xvector_model(wavs.unsqueeze(1).float())
|
| 214 |
+
return torch.mean(embed, dim=0).detach().cpu()
|
| 215 |
+
|
| 216 |
+
xvector_model = XVectorSincNet()
|
| 217 |
+
model_file = "model-bin/speaker_embedding/xvector_sincnet.pt"
|
| 218 |
+
meta = torch.load(model_file, map_location='cpu')['state_dict']
|
| 219 |
+
print('load_xvector_sincnet_model', xvector_model.load_state_dict(meta, strict=False))
|
| 220 |
+
xvector_model = xvector_model.eval()
|
| 221 |
+
for param in xvector_model.parameters():
|
| 222 |
+
param.requires_grad = False
|
| 223 |
+
"""
|