import torch
import typer

from utils import sound_split

app = typer.Typer()
@app.command()
def model_summary() -> None:
    """Print the UNet architecture."""
    from unet import UNet

    net = UNet()
    print(net)
@app.command()
def test() -> None:
    """Run a random batch through the UNet to check input/output shapes."""
    from unet import UNet

    batch_size = 5
    n_channels = 2
    # Dummy input batch of shape (batch, channels, 512, 128).
    x = torch.randn(batch_size, n_channels, 512, 128)
    print(x.shape)

    net = UNet(in_channels=n_channels)
    y = net(x)  # equivalent to net.forward(x), but runs module hooks
    print(y.shape)
@app.command()
def split(
    model_path: str = "models/2stems/model",
    input: str = "data/audio_example.mp3",
    output_dir: str = "output",
    offset: float = 0,
    duration: float = 30,
    write_src: bool = False,
) -> None:
    """Split an audio file into stems with a pretrained Splitter model."""
    from splitter import Splitter

    # Prefer GPU when available; otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    splitter_model = Splitter.from_pretrained(model_path).to(device).eval()
    # Note: offset and duration are accepted but not forwarded to sound_split here.
    sound_split(splitter_model, input, output_dir, write_src)
if __name__ == "__main__":
app()
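
A minimal usage sketch, assuming this file is saved as main.py (a hypothetical filename) and the unet, splitter, and utils modules are importable from the working directory. Typer renders underscores in command and option names as hyphens by default:

# Example invocations:
#   python main.py model-summary
#   python main.py test
#   python main.py split --input data/audio_example.mp3 --output-dir output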