Spaces:
Runtime error
Runtime error
File size: 1,239 Bytes
a8c39f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from typing import Optional

import torch
# Handle on the genuine implementation, captured before any patching below,
# so the wrapper can still delegate to it once ``torch.stft`` is replaced.
_torch_stft = torch.stft


def z_stft(
    audio: torch.Tensor,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[torch.Tensor] = None,
    center: bool = True,
    pad_mode: str = "reflect",
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
) -> torch.Tensor:
    """Compute ``torch.stft`` on the CPU and return it on ``audio``'s device.

    Installed over ``torch.stft`` by the ZLUDA guard at the end of this module.
    ``audio`` (and ``window``, when given) are copied to the CPU, every other
    argument is forwarded unchanged to the original ``torch.stft``, and the
    result is moved back to the device ``audio`` originally lived on.

    NOTE(review): the first parameter is named ``audio`` while ``torch.stft``
    names it ``input``, so a patched ``torch.stft(input=...)`` call would fail.

    Args:
        audio: Signal to transform.
        n_fft: FFT size, forwarded as-is.
        hop_length, win_length, center, pad_mode, normalized, onesided,
            return_complex: Forwarded as-is; see ``torch.stft``.
        window: Optional window tensor. ``None`` is forwarded unchanged,
            matching ``torch.stft``'s own default.

    Returns:
        Whatever the original ``torch.stft`` returns, placed on ``audio.device``.
    """
    source_device = audio.device
    # ``window`` is optional in torch.stft; only move it when one was supplied.
    cpu_window = window.to("cpu") if window is not None else None
    return _torch_stft(
        audio.to("cpu"),
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=cpu_window,
        center=center,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=onesided,
        return_complex=return_complex,
    ).to(source_device)


def z_jit(f, *_, **__):
    """No-op stand-in for ``torch.jit.script``.

    Returns ``f`` itself instead of a scripted object, after attaching an empty
    ``torch._C.Graph`` as ``f.graph`` -- presumably so code that reads the
    ``graph`` attribute of a scripted object keeps working. Any extra
    positional or keyword arguments are accepted and ignored.
    """
    f.graph = torch._C.Graph()
    return f


# Only devices whose CUDA name ends with "[ZLUDA]" get the workarounds below;
# on every other setup torch is left untouched.
if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
    # hijacks: STFT is routed through the CPU, TorchScript becomes a no-op
    torch.stft = z_stft
    torch.jit.script = z_jit
    # disabling unsupported cudnn, plus the flash and memory-efficient SDP
    # kernels; only the math SDP backend is left enabled
    torch.backends.cudnn.enabled = False
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
|