"""Typer command-line entry points for inspecting the U-Net and splitting audio."""

from pathlib import Path

import torch
import typer

from utils import sound_split

app = typer.Typer()


@app.command()
def model_summary() -> None:
    """Print the module structure of a default-constructed UNet."""
    # Local import: the model module is only loaded when this command runs.
    from unet import UNet

    net = UNet()
    print(net)


@app.command()
def test() -> None:
    """Run a random batch through UNet and print the input and output shapes."""
    from unet import UNet

    batch_size = 5
    n_channels = 2
    x = torch.randn(batch_size, n_channels, 512, 128)
    print(x.shape)

    net = UNet(in_channels=n_channels)
    # Call the module itself rather than ``forward`` so registered hooks run,
    # and skip autograd bookkeeping since only the output shape is inspected.
    with torch.no_grad():
        y = net(x)
    print(y.shape)


@app.command()
def split(
    model_path: str = "models/2stems/model",
    # ``input`` shadows the builtin, but the name is kept because Typer
    # derives the CLI option from it.
    input: str = "data/audio_example.mp3",
    output_dir: str = "output",
    offset: float = 0,
    duration: float = 30,
    write_src: bool = False,
) -> None:
    """Load a pretrained Splitter and run ``sound_split`` on an input file.

    The model is loaded from ``model_path``, moved to ``cuda:0`` when CUDA is
    available (otherwise the CPU), and switched to eval mode before being
    handed to ``sound_split`` together with ``input``, ``output_dir`` and
    ``write_src``.
    """
    from splitter import Splitter

    # NOTE(review): ``offset`` and ``duration`` are accepted on the command
    # line but are not forwarded to ``sound_split``; confirm whether that
    # helper supports them before wiring them through.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    splitter_model = Splitter.from_pretrained(model_path).to(device).eval()

    sound_split(splitter_model, input, output_dir, write_src)


if __name__ == "__main__":
    app()