from pathlib import Path

import torch
import typer

from utils import sound_split

app = typer.Typer()

@app.command()
def model_summary() -> None:
    """Print the UNet architecture."""
    from unet import UNet

    net = UNet()
    print(net)

@app.command()
def test() -> None:
    """Run a forward pass through an untrained UNet to check tensor shapes."""
    from unet import UNet

    batch_size = 5
    n_channels = 2
    x = torch.randn(batch_size, n_channels, 512, 128)
    print(x.shape)
    net = UNet(in_channels=n_channels)
    y = net(x)
    print(y.shape)

@app.command()
def split(
    model_path: str = "models/2stems/model",
    input: str = "data/audio_example.mp3",
    output_dir: str = "output",
    offset: float = 0,
    duration: float = 30,
    write_src: bool = False,
) -> None:
    """Separate an audio file into stems with a pretrained Splitter model."""
    from splitter import Splitter

    # Run on GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    splitter_model = Splitter.from_pretrained(model_path).to(device).eval()
    # offset and duration are accepted but not passed to sound_split here.
    sound_split(splitter_model, input, output_dir, write_src)

if __name__ == "__main__":
    app()
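
For reference, here is a minimal sketch of exercising the split command in-process with Typer's test runner. It assumes this script is saved as main.py and that the pretrained model and example audio exist at the default paths; both are assumptions, not part of the original file.

# Sketch: invoke the `split` command in-process via Typer's CliRunner.
# Assumes this script is saved as main.py (filename is an assumption).
from typer.testing import CliRunner

from main import app

runner = CliRunner()
result = runner.invoke(
    app,
    ["split", "--input", "data/audio_example.mp3", "--output-dir", "output"],
)
print(result.exit_code)  # 0 on success
print(result.output)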