File size: 1,187 Bytes
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from pathlib import Path
from typing import Callable

from torch import Tensor


def walk_paths(root, suffix):
    """Recursively yield every non-directory entry under ``root`` with the given suffix.

    Args:
        root: Directory to start from; anything accepted by ``Path(...)``.
        suffix: Extension to match, compared for exact equality against
            ``Path.suffix`` (so it includes the leading dot, e.g. ``".wav"``,
            and the comparison is case-sensitive).

    Yields:
        ``Path`` objects for each matching entry, descending into
        sub-directories in the order ``iterdir()`` reports them.
    """
    for entry in Path(root).iterdir():
        if not entry.is_dir():
            # Leaf entry: keep it only when the extension matches exactly.
            if entry.suffix == suffix:
                yield entry
            continue
        # Directory: descend and forward everything found below it.
        yield from walk_paths(entry, suffix)


def rglob_audio_files(path: Path):
    """Return a list of all ``.wav`` and ``.flac`` files found under ``path``.

    The directory tree is searched once per extension via ``walk_paths``;
    the resulting list holds every ``.wav`` match first, followed by every
    ``.flac`` match.
    """
    audio_files = []
    for extension in (".wav", ".flac"):
        audio_files.extend(walk_paths(path, extension))
    return audio_files


def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
    """
    Args:
        fg: (b, t)
        bg: (b, t)
    """
    assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
    fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
    bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)

    fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
    bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)

    fg = fg / (fg_energy + eps).sqrt()
    bg = bg / (bg_energy + eps).sqrt()

    if callable(alpha):
        alpha = alpha()

    assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"

    mx = alpha * fg + (1 - alpha) * bg
    mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)

    return mx