import os
from datetime import datetime
from pathlib import Path
from typing import Optional

import torch
import typer
from accelerate import Accelerator
from accelerate.utils import LoggerType
from torch import Tensor
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import MusdbDataset
from splitter import Splitter

# os.environ.get returns a *string* when the variable is set; coerce to bool
# explicitly so e.g. DISABLE_TQDM="" stays falsy and DISABLE_TQDM=1 disables.
DISABLE_TQDM = bool(os.environ.get("DISABLE_TQDM", False))

app = typer.Typer(pretty_exceptions_show_locals=False)


def spectrogram_loss(masked_target: Tensor, original: Tensor) -> Tensor:
    """Return the mean squared error between two STFT tensors.

    Args:
        masked_target: a masked STFT generated by applying a net's estimated
            mask for source S to the ground truth STFT for source S.
        original: an original input mixture STFT.

    Returns:
        Scalar tensor: mean of the element-wise squared difference.
    """
    square_difference = torch.square(masked_target - original)
    loss_value = torch.mean(square_difference)
    return loss_value


@app.command()
def train(
    dataset: str = "data/musdb18-wav",
    output_dir: Optional[str] = None,
    fp16: bool = False,
    cpu: bool = True,
    max_steps: int = 100,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 1,
    effective_batch_size: int = 4,
    max_grad_norm: float = 0.0,
) -> None:
    """Train a Splitter source-separation model on a MUSDB-style dataset.

    Args:
        dataset: root directory of the MUSDB wav dataset.
        output_dir: where to save the model; defaults to a timestamped
            directory under ``experiments/``.
        fp16: enable mixed-precision training via accelerate.
        cpu: force CPU execution.
        max_steps: total optimizer steps; if <= 0, derived from
            ``len(dataloader) * num_train_epochs``.
        num_train_epochs: epoch count used only when ``max_steps <= 0``.
        per_device_train_batch_size: samples per device per forward pass
            (currently must be 1).
        effective_batch_size: target batch size achieved via gradient
            accumulation.
        max_grad_norm: if > 0, clip gradients to this norm before stepping.

    Raises:
        ValueError: if ``per_device_train_batch_size`` is not 1.
    """
    # The loop below squeezes the target waveforms, which only works for a
    # batch of one. Validate up front with a real exception: ``assert`` is
    # stripped under ``python -O``.
    if per_device_train_batch_size != 1:
        raise ValueError("For now limit to 1.")

    if not output_dir:
        now_str = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = f"experiments/{now_str}"
    output_dir = Path(output_dir)
    logging_dir = output_dir / "tracker_logs"

    accelerator = Accelerator(
        fp16=fp16,
        cpu=cpu,
        logging_dir=logging_dir,
        log_with=[LoggerType.TENSORBOARD],
    )
    # NOTE(review): init_trackers documents a *name* argument; a Path appears
    # to work with the TensorBoard tracker — confirm against the accelerate
    # version pinned by this project.
    accelerator.init_trackers(logging_dir / "run")

    train_dataset = MusdbDataset(root=dataset, is_train=True)
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=per_device_train_batch_size,
    )
    model = Splitter(stem_names=list(train_dataset.targets))
    optimizer = AdamW(
        model.parameters(),
        lr=1e-3,
        eps=1e-8,
    )

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    num_train_steps = (
        max_steps if max_steps > 0 else len(train_dataloader) * num_train_epochs
    )
    accelerator.print(f"Num train steps: {num_train_steps}")

    step_batch_size = per_device_train_batch_size * accelerator.num_processes
    gradient_accumulation_steps = max(
        1,
        effective_batch_size // step_batch_size,
    )
    accelerator.print(
        f"Gradient Accumulation Steps: {gradient_accumulation_steps}\n"
        f"Effective Batch Size: {gradient_accumulation_steps * step_batch_size}"
    )

    global_step = 0
    while global_step < num_train_steps:
        accelerator.wait_for_everyone()
        model.train()
        batch_iterator = tqdm(
            train_dataloader,
            desc="Batch",
            disable=((not accelerator.is_local_main_process) or DISABLE_TQDM),
        )
        for batch_idx, batch in enumerate(batch_iterator):
            x_wav, y_target_wavs = batch
            # predictions maps stem name -> masked STFT estimated by the net.
            predictions = model(x_wav)

            stem_losses = []
            for name, masked_stft in predictions.items():
                target_stft, _ = model.compute_stft(y_target_wavs[name].squeeze())
                loss = spectrogram_loss(
                    masked_target=masked_stft,
                    original=target_stft,
                )
                stem_losses.append(loss)
                accelerator.log({f"train-loss-{name}": 1.0 * loss}, step=global_step)

            # Divide by the accumulation count so accumulated gradients
            # average (rather than sum) across micro-batches.
            total_loss = (
                torch.sum(torch.stack(stem_losses)) / gradient_accumulation_steps
            )
            accelerator.print(f"global step: {global_step}\tloss: {total_loss:.4f}")
            accelerator.log({"train-loss": 1.0 * total_loss}, step=global_step)
            accelerator.backward(total_loss)

            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                if max_grad_norm > 0:
                    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                # BUGFIX: stop mid-epoch once the step budget is spent.
                # Previously the limit was only checked at epoch boundaries,
                # so a large dataloader could overshoot max_steps by up to a
                # full epoch.
                if global_step >= num_train_steps:
                    break

        accelerator.wait_for_everyone()

    accelerator.end_training()

    accelerator.print(f"Saving model to {output_dir}...")
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )
    accelerator.wait_for_everyone()
    accelerator.print("DONE!")


if __name__ == "__main__":
    app()