from functools import reduce
from inspect import isfunction
from math import ceil, floor, log2, pi
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Generator, Tensor
from typing_extensions import TypeGuard
T = TypeVar("T")


def exists(val: Optional[T]) -> TypeGuard[T]:
    return val is not None


def iff(condition: bool, value: T) -> Optional[T]:
    return value if condition else None


def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
    return isinstance(obj, (list, tuple))


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    if exists(val):
        return val
    return d() if isfunction(d) else d
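
# Illustrative usage (an assumption, not part of the original module):
# `default` returns `val` unless it is None, otherwise the fallback,
# calling the fallback first if it is a function.
#   default(7, 3)               -> 7
#   default(None, 3)            -> 3
#   default(None, lambda: [1])  -> [1]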

def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    if isinstance(val, tuple):
        return list(val)
    if isinstance(val, list):
        return val
    return [val]  # type: ignore


def prod(vals: Sequence[int]) -> int:
    return reduce(lambda x, y: x * y, vals)


def closest_power_2(x: float) -> int:
    exponent = log2(x)
    distance_fn = lambda z: abs(x - 2 ** z)  # noqa
    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
    return 2 ** int(exponent_closest)
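
# Illustrative sanity check (an assumption, not part of the original module):
# closest_power_2 compares the two neighboring powers of two in linear
# distance, so values round to whichever power is closer on the number line.
#   closest_power_2(5)  -> 4
#   closest_power_2(7)  -> 8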
"""
Kwargs Utils
"""
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    return_dicts: Tuple[Dict, Dict] = ({}, {})
    for key in d.keys():
        no_prefix = int(not key.startswith(prefix))
        return_dicts[no_prefix][key] = d[key]
    return return_dicts


def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
    if keep_prefix:
        return kwargs_with_prefix, kwargs
    kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
    return kwargs_no_prefix, kwargs


def prefix_dict(prefix: str, d: Dict) -> Dict:
    return {prefix + str(k): v for k, v in d.items()}
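
# Illustrative usage (an assumption, not part of the original module):
# groupby splits a kwargs dict on a key prefix, stripping the prefix from
# the matched keys by default.
#   groupby("attention_", dict(attention_heads=8, channels=64))
#   -> ({"heads": 8}, {"channels": 64})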
"""
DSP Utils
"""
def resample(
    waveforms: Tensor,
    factor_in: int,
    factor_out: int,
    rolloff: float = 0.99,
    lowpass_filter_width: int = 6,
) -> Tensor:
    """Resamples a waveform using sinc interpolation, adapted from torchaudio"""
    b, _, length = waveforms.shape
    length_target = int(factor_out * length / factor_in)
    d = dict(device=waveforms.device, dtype=waveforms.dtype)
    # Anti-aliasing cutoff: the lower of the two rates, shrunk by the rolloff
    base_factor = min(factor_in, factor_out) * rolloff
    # Kernel half-width in input samples
    width = ceil(lowpass_filter_width * factor_in / base_factor)
    idx = torch.arange(-width, width + factor_in, **d)[None, None] / factor_in  # type: ignore # noqa
    # One row of tap positions per output phase: shape (factor_out, 1, kernel_size)
    t = torch.arange(0, -factor_out, step=-1, **d)[:, None, None] / factor_out + idx  # type: ignore # noqa
    t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * pi
    # Hann window (cosine squared) tapers the truncated sinc
    window = torch.cos(t / lowpass_filter_width / 2) ** 2
    scale = base_factor / factor_in
    # sinc(t), with the removable singularity at t == 0 handled explicitly
    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
    kernels *= window * scale
    waveforms = rearrange(waveforms, "b c t -> (b c) t")
    waveforms = F.pad(waveforms, (width, width + factor_in))
    # Polyphase resampling: a strided conv computes all output phases at once,
    # then the phases are interleaved back into a single time axis
    resampled = F.conv1d(waveforms[:, None], kernels, stride=factor_in)
    resampled = rearrange(resampled, "(b c) k l -> b c (l k)", b=b)
    return resampled[..., :length_target]

def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    return resample(waveforms, factor_in=factor, factor_out=1, **kwargs)


def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    return resample(waveforms, factor_in=1, factor_out=factor, **kwargs)
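
# Illustrative usage (an assumption, not part of the original module),
# for waveforms shaped (batch, channels, time):
#   x = torch.randn(1, 1, 2**15)
#   y = downsample(x, factor=2)  # shape (1, 1, 2**14)
#   z = upsample(y, factor=2)    # shape (1, 1, 2**15)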
""" Torch Utils """

def randn_like(tensor: Tensor, *args, generator: Optional[Generator] = None, **kwargs):
    """randn_like that supports generator"""
    return torch.randn(tensor.shape, *args, generator=generator, **kwargs).to(tensor)
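
# Illustrative usage (an assumption, not part of the original module):
# unlike torch.randn_like, this accepts a torch.Generator for reproducible
# noise, while still matching the device and dtype of the reference tensor.
#   g = torch.Generator().manual_seed(0)
#   noise = randn_like(x, generator=g)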