pengdaqian
fix
62e9d65
raw
history blame contribute delete
991 Bytes
from pathlib import Path
import torch
import typer
from utils import sound_split
# Typer application object; the functions decorated with @app.command()
# below are registered on it as CLI sub-commands.
app = typer.Typer()
@app.command()
def model_summary() -> None:
    """Print the representation of a default-constructed ``UNet``."""
    # Function-scope import: ``unet`` is only loaded when this command runs.
    from unet import UNet

    print(UNet())
@app.command()
def test() -> None:
    """Push a random batch through ``UNet`` and print input/output shapes.

    Builds a ``(5, 2, 512, 128)`` tensor of Gaussian noise, constructs a
    ``UNet`` with ``in_channels`` equal to the tensor's second dimension,
    runs it once, and prints the tensor shape before and after.
    """
    # Function-scope import: ``unet`` is only loaded when this command runs.
    from unet import UNet

    batch_size = 5
    n_channels = 2
    x = torch.randn(batch_size, n_channels, 512, 128)
    print(x.shape)

    net = UNet(in_channels=n_channels)
    # Call the module instead of ``forward`` directly so that any registered
    # hooks run (assumes UNet is a torch.nn.Module -- TODO confirm).
    # Gradients are not needed just to inspect the output shape.
    with torch.no_grad():
        y = net(x)
    print(y.shape)
@app.command()
def split(
    model_path: str = "models/2stems/model",
    input: str = "data/audio_example.mp3",
    output_dir: str = "output",
    offset: float = 0,
    duration: float = 30,
    write_src: bool = False,
) -> None:
    """Load a pretrained ``Splitter`` and hand it to ``sound_split``.

    Args:
        model_path: Forwarded to ``Splitter.from_pretrained``.
        input: Forwarded to ``sound_split`` as its second argument.
        output_dir: Forwarded to ``sound_split`` as its third argument.
        offset: Accepted but not used inside this function.
        duration: Accepted but not used inside this function.
        write_src: Forwarded to ``sound_split`` as its fourth argument.
    """
    # Function-scope import: ``splitter`` is only loaded when this command runs.
    from splitter import Splitter

    # Prefer the first CUDA device when one is available, else use the CPU.
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    target_device = torch.device(device_name)

    model = Splitter.from_pretrained(model_path)
    model = model.to(target_device)
    model = model.eval()

    # NOTE(review): ``offset`` and ``duration`` are never passed on -- confirm
    # whether ``sound_split`` is supposed to receive them.
    sound_split(model, input, output_dir, write_src)
# Run the Typer app only when this file is executed as a script.
if __name__ == "__main__":
    app()