|
|
|
|
|
|
|
|
|
|
|
|
|
import random
|
|
|
|
import torch
|
|
|
|
from audiocraft.losses import (
|
|
MelSpectrogramL1Loss,
|
|
MultiScaleMelSpectrogramLoss,
|
|
MRSTFTLoss,
|
|
SISNR,
|
|
STFTLoss,
|
|
)
|
|
from audiocraft.losses.loudnessloss import TFLoudnessRatio
|
|
from audiocraft.losses.wmloss import WMMbLoss
|
|
from tests.common_utils.wav_utils import get_white_noise
|
|
|
|
|
|
def test_mel_l1_loss():
    """Smoke-test MelSpectrogramL1Loss on random multichannel signals.

    Checks that the loss returns a tensor for differing inputs and that
    comparing a signal against itself yields exactly 0.0.
    """
    batch, channels = 2, 2
    # Random length so different runs exercise different framing edge cases.
    length = random.randrange(1000, 100_000)
    prediction = torch.randn(batch, channels, length)
    reference = torch.randn(batch, channels, length)

    criterion = MelSpectrogramL1Loss(sample_rate=22_050)
    mismatch_loss = criterion(prediction, reference)
    identity_loss = criterion(prediction, prediction)

    for value in (mismatch_loss, identity_loss):
        assert isinstance(value, torch.Tensor)
    assert identity_loss.item() == 0.0
|
|
|
|
|
|
def test_msspec_loss():
    """Smoke-test MultiScaleMelSpectrogramLoss on random multichannel signals.

    Checks that the loss returns a tensor for differing inputs and that
    comparing a signal against itself yields exactly 0.0.
    """
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    prediction = torch.randn(batch, channels, length)
    reference = torch.randn(batch, channels, length)

    criterion = MultiScaleMelSpectrogramLoss(sample_rate=22_050)
    mismatch_loss = criterion(prediction, reference)
    identity_loss = criterion(prediction, prediction)

    for value in (mismatch_loss, identity_loss):
        assert isinstance(value, torch.Tensor)
    assert identity_loss.item() == 0.0
|
|
|
|
|
|
def test_mrstft_loss():
    """Smoke-test MRSTFTLoss (default config): it must return a tensor."""
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    prediction = torch.randn(batch, channels, length)
    reference = torch.randn(batch, channels, length)

    result = MRSTFTLoss()(prediction, reference)

    assert isinstance(result, torch.Tensor)
|
|
|
|
|
|
def test_sisnr_loss():
    """Smoke-test SISNR (default config): it must return a tensor."""
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    estimate = torch.randn(batch, channels, length)
    target = torch.randn(batch, channels, length)

    result = SISNR()(estimate, target)

    assert isinstance(result, torch.Tensor)
|
|
|
|
|
|
def test_stft_loss():
    """Smoke-test single-resolution STFTLoss (default config): it must return a tensor."""
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    prediction = torch.randn(batch, channels, length)
    reference = torch.randn(batch, channels, length)

    # Named for what it is (a single-resolution STFTLoss), not MRSTFTLoss.
    stft_criterion = STFTLoss()
    result = stft_criterion(prediction, reference)

    assert isinstance(result, torch.Tensor)
|
|
|
|
|
|
def test_wm_loss():
    """Smoke-test WMMbLoss with an "mse" criterion: it must return a tensor.

    The detector output has ``2 + nbits`` channels per sample; the negative
    input is passed as ``None``.
    """
    batch, nbits = 2, 16
    length = random.randrange(1000, 100_000)
    # (batch, 2 + nbits, length); presumably 2 detection channels followed by
    # nbits message channels — confirm against WMMbLoss.forward.
    positive = torch.randn(batch, 2 + nbits, length)
    # Third positional argument, shaped (batch, 1, length); likely a mask —
    # NOTE(review): confirm its meaning against WMMbLoss.forward.
    aux = torch.randn(batch, 1, length)
    message = torch.randn(batch, nbits)

    criterion = WMMbLoss(0.3, "mse")
    result = criterion(positive, None, aux, message)

    assert isinstance(result, torch.Tensor)
|
|
|
|
|
|
def test_loudness_loss():
    """Smoke-test TFLoudnessRatio on one second of white noise compared to itself."""
    sample_rate = 16_000
    seconds = 1.0
    num_samples = int(sample_rate * seconds)
    # Add a leading batch dimension to the (1, num_samples) noise signal.
    noise = get_white_noise(1, num_samples).unsqueeze(0)
    criterion = TFLoudnessRatio(sample_rate=sample_rate, n_bands=1)

    result = criterion(noise, noise)

    assert isinstance(result, torch.Tensor)
|
|
|