File size: 991 Bytes
62e9d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from pathlib import Path

import torch
import typer
from utils import sound_split

# Typer application object; every function decorated with ``@app.command()``
# in this file is registered on it as a CLI subcommand.
app: typer.Typer = typer.Typer()


@app.command()
def model_summary() -> None:
    """Print a ``UNet`` built with its default constructor arguments.

    The model's printed representation is written to stdout. (Assumes
    ``UNet`` is a ``torch.nn.Module``, whose repr lists its layers; confirm
    in ``unet``.)
    """
    from unet import UNet

    print(UNet())


@app.command()
def test() -> None:
    """Smoke-test ``UNet`` on a random batch and print input/output shapes.

    Builds a random tensor of shape ``(batch_size, n_channels, 512, 128)``,
    runs it through a ``UNet`` constructed with ``in_channels=n_channels``,
    and prints the input shape followed by the output shape.
    """
    from unet import UNet

    batch_size = 5
    n_channels = 2
    x = torch.randn(batch_size, n_channels, 512, 128)
    print(x.shape)
    net = UNet(in_channels=n_channels)
    # Invoke the module (not ``.forward``) so any registered hooks run, and
    # disable autograd: only the output shape is inspected, so building a
    # gradient graph would just waste memory.
    with torch.no_grad():
        y = net(x)
    print(y.shape)


@app.command()
def split(
    model_path: str = "models/2stems/model",
    input: str = "data/audio_example.mp3",
    output_dir: str = "output",
    offset: float = 0,
    duration: float = 30,
    write_src: bool = False,
) -> None:
    """Run a pretrained ``Splitter`` over an audio file via ``sound_split``.

    The model is loaded from ``model_path`` and moved to ``cuda:0`` when CUDA
    is available (CPU otherwise). ``.eval()`` is called on it, and it is
    passed to ``sound_split`` together with ``input``, ``output_dir`` and
    ``write_src``.

    NOTE(review): ``offset`` and ``duration`` are accepted on the command
    line but are not forwarded to ``sound_split``. Confirm whether
    ``sound_split`` should receive them.
    """
    from splitter import Splitter

    # Prefer the first GPU when one is present.
    if torch.cuda.is_available():
        target = torch.device("cuda:0")
    else:
        target = torch.device("cpu")

    model = Splitter.from_pretrained(model_path)
    model = model.to(target)
    model = model.eval()

    sound_split(model, input, output_dir, write_src)


if __name__ == "__main__":
    # Dispatch to the Typer CLI when this file is executed as a script.
    app()