import torch
import soundfile as sf
import numpy as np
import argparse
import os
import yaml
import julius
import sys

# Make the parent directory importable so the project-local packages below resolve.
currentdir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.dirname(currentdir))
from networks import Dasp_Mastering_Style_Transfer, Effects_Encoder
from modules.loss import AudioFeatureLoss, Loss


def convert_audio(wav: torch.Tensor, from_rate: float, to_rate: float, to_channels: int) -> torch.Tensor:
    """Convert audio to new sample rate and number of audio channels.

    Args:
        wav: Audio tensor to convert.
        from_rate: Current sample rate; truncated to ``int`` before resampling.
        to_rate: Target sample rate; truncated to ``int`` before resampling.
        to_channels: Desired number of audio channels.

    Returns:
        The resampled, channel-converted tensor.
    """
    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
    # NOTE(review): `convert_audio_channels` is neither defined nor imported in
    # this file as visible here -- confirm it is provided elsewhere, otherwise
    # this call raises NameError.
    wav = convert_audio_channels(wav, to_channels)
    return wav


class MasteringStyleTransfer:
    """Wraps the effects encoder and mastering converter for inference.

    ``args`` must provide ``cfg_enc``, ``encoder_path``, ``cfg_converter``,
    ``model_path`` and ``sample_rate``.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load models
        self.effects_encoder = self.load_effects_encoder()
        self.mastering_converter = self.load_mastering_converter()

    def load_effects_encoder(self):
        """Build the effects encoder, load its checkpoint and put it in eval mode."""
        effects_encoder = Effects_Encoder(self.args.cfg_enc)
        reload_weights(effects_encoder, self.args.encoder_path, self.device)
        effects_encoder.to(self.device)
        effects_encoder.eval()
        return effects_encoder

    def load_mastering_converter(self):
        """Build the mastering converter, load its checkpoint and put it in eval mode."""
        mastering_converter = Dasp_Mastering_Style_Transfer(
            num_features=2048,
            sample_rate=self.args.sample_rate,
            tgt_fx_names=['eq', 'distortion', 'multiband_comp', 'gain', 'imager', 'limiter'],
            model_type='tcn',
            config=self.args.cfg_converter,
            batch_size=1,
        )
        reload_weights(mastering_converter, self.args.model_path, self.device)
        mastering_converter.to(self.device)
        mastering_converter.eval()
        return mastering_converter

    def get_reference_embedding(self, reference_tensor):
        """Run the effects encoder on ``reference_tensor`` without tracking gradients."""
        with torch.no_grad():
            reference_feature = self.effects_encoder(reference_tensor)
        return reference_feature

    def mastering_style_transfer(self, input_tensor, reference_feature):
        """Apply the mastering converter once, without tracking gradients.

        Returns:
            ``(output_audio, predicted_params)`` where ``predicted_params`` is
            whatever the converter reports via ``get_last_predicted_params()``.
        """
        with torch.no_grad():
            output_audio = self.mastering_converter(input_tensor, reference_feature)
            predicted_params = self.mastering_converter.get_last_predicted_params()
        return output_audio, predicted_params

    def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
        """Optimize the reference embedding so the output matches ``reference_tensor``.

        This is a generator. Each completed step yields
        ``(log_entry, output_audio, current_params, step_number, loss_value)``
        with a 1-based ``step_number``. The step on which divergence is
        detected breaks out of the loop without yielding. The generator's
        return value (``StopIteration.value``) is
        ``(min_loss_output, min_loss_params, min_loss_embedding, min_loss_step + 1)``.

        Args:
            input_tensor: Audio passed to the mastering converter.
            reference_tensor: Audio the loss compares the output against.
            ito_config: Mapping with keys ``optimizer`` (name of a class in
                ``torch.optim``), ``learning_rate``, ``af_weights``,
                ``sample_rate`` and ``num_steps``.
            initial_reference_feature: Starting embedding; it is not modified.
        """
        # Clone before wrapping: nn.Parameter shares storage with the tensor it
        # is given, so optimizing it directly would overwrite the caller's
        # initial embedding in place.
        fit_embedding = torch.nn.Parameter(initial_reference_feature.detach().clone())
        optimizer_cls = getattr(torch.optim, ito_config['optimizer'])
        optimizer = optimizer_cls([fit_embedding], lr=ito_config['learning_rate'])

        af_loss = AudioFeatureLoss(
            weights=ito_config['af_weights'],
            sample_rate=ito_config['sample_rate'],
            stem_separation=False,
            use_clap=False
        )

        min_loss = float('inf')
        min_loss_step = 0
        min_loss_output = None
        min_loss_params = None
        min_loss_embedding = None

        loss_history = []
        divergence_counter = 0
        initial_params = None

        for step in range(ito_config['num_steps']):
            optimizer.zero_grad()

            output_audio = self.mastering_converter(input_tensor, fit_embedding)
            current_params = self.mastering_converter.get_last_predicted_params()

            losses = af_loss(output_audio, reference_tensor)
            total_loss = sum(losses.values())
            loss_value = total_loss.item()

            loss_history.append(loss_value)

            # Track the best result; the embedding is snapshotted before the
            # optimizer updates it, so it is the one that produced this output.
            if loss_value < min_loss:
                min_loss = loss_value
                min_loss_step = step
                min_loss_output = output_audio.detach()
                min_loss_params = current_params
                min_loss_embedding = fit_embedding.detach().clone()

            # Check for divergence: loss is higher than it was 10 steps ago
            if len(loss_history) > 10 and loss_value > loss_history[-11]:
                divergence_counter += 1
            else:
                divergence_counter = 0

            # Log top 5 parameter differences
            if step == 0:
                initial_params = current_params
            top_5_diff = self.get_top_n_diff_string(initial_params, current_params, top_n=5)
            log_entry = f"Step {step + 1}, Loss: {loss_value:.4f}\n{top_5_diff}\n"

            if divergence_counter >= 10:
                print(f"Optimization stopped early due to divergence at step {step}")
                break

            total_loss.backward()
            optimizer.step()

            yield log_entry, output_audio.detach(), current_params, step + 1, loss_value

        return min_loss_output, min_loss_params, min_loss_embedding, min_loss_step + 1

    def preprocess_audio(self, audio, target_sample_rate=44100):
        """Convert a ``(sample_rate, ndarray)`` pair into a stereo float tensor.

        Args:
            audio: Tuple ``(sample_rate, data)``; ``data`` is an ``int16`` or
                ``float32`` NumPy array that is 1-D or 2-D.
            target_sample_rate: Rate to resample to when it differs from the
                input rate.

        Returns:
            A float tensor with a leading batch dimension of 1, moved to
            ``self.device``.

        Raises:
            ValueError: If the dtype is not ``int16``/``float32`` or the array
                has more than two dimensions.
        """
        sample_rate, data = audio

        # Normalize audio to -1 to 1 range
        if data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
        elif data.dtype == np.float32:
            data = np.clip(data, -1.0, 1.0)
        else:
            raise ValueError(f"Unsupported audio data type: {data.dtype}")

        # Ensure stereo channels
        if data.ndim == 1:
            data = np.stack([data, data])
        elif data.ndim == 2:
            if data.shape[0] == 2:
                pass  # Already in correct shape
            elif data.shape[1] == 2:
                data = data.T
            elif data.shape[0] == 1:
                # Mono stored as (1, samples): duplicate the single row.
                # Taking column 0 here would keep only the first sample.
                data = np.concatenate([data, data], axis=0)
            else:
                data = np.stack([data[:, 0], data[:, 0]])  # Duplicate mono channel
        else:
            raise ValueError(f"Unsupported audio shape: {data.shape}")

        # Convert to torch tensor
        data_tensor = torch.FloatTensor(data).unsqueeze(0)

        # Resample if necessary
        if sample_rate != target_sample_rate:
            data_tensor = julius.resample_frac(data_tensor, sample_rate, target_sample_rate)

        return data_tensor.to(self.device)

    def process_audio(self, input_audio, reference_audio, ito_reference_audio):
        """Run one style-transfer pass of ``input_audio`` toward ``reference_audio``.

        ``ito_reference_audio`` is accepted for interface compatibility but is
        not used by this method, so it is not preprocessed here.

        Returns:
            ``(output_audio, predicted_params, sample_rate)`` where
            ``sample_rate`` is ``self.args.sample_rate``.
        """
        input_tensor = self.preprocess_audio(input_audio, self.args.sample_rate)
        reference_tensor = self.preprocess_audio(reference_audio, self.args.sample_rate)

        reference_feature = self.get_reference_embedding(reference_tensor)

        output_audio, predicted_params = self.mastering_style_transfer(input_tensor, reference_feature)

        return output_audio, predicted_params, self.args.sample_rate

    def print_param_difference(self, initial_params, ito_params):
        """Print every parameter difference, then the ten largest.

        Dict-valued effects are normalized by the parameter range taken from
        ``self.mastering_converter.fx_processors``; any other effect value is
        reported as ``width`` with the absolute difference used as-is.
        """
        all_diffs = []

        print("\nAll parameter differences:")
        for fx_name, initial_fx in initial_params.items():
            print(f"\n{fx_name.upper()}:")
            if isinstance(initial_fx, dict):
                for param_name, initial_value in initial_fx.items():
                    ito_value = ito_params[fx_name][param_name]

                    # Calculate normalized difference
                    param_range = self.mastering_converter.fx_processors[fx_name].param_ranges[param_name]
                    normalized_diff = abs((ito_value - initial_value) / (param_range[1] - param_range[0]))

                    all_diffs.append((fx_name, param_name, initial_value, ito_value, normalized_diff))

                    print(f" {param_name}:")
                    print(f" Initial: {initial_value.item():.4f}")
                    print(f" ITO: {ito_value.item():.4f}")
                    print(f" Normalized Diff: {normalized_diff.item():.4f}")
            else:
                initial_value = initial_fx
                ito_value = ito_params[fx_name]

                # For 'imager', assume range is 0 to 1
                normalized_diff = abs(ito_value - initial_value)

                all_diffs.append((fx_name, 'width', initial_value, ito_value, normalized_diff))

                print(f" width:")
                print(f" Initial: {initial_value.item():.4f}")
                print(f" ITO: {ito_value.item():.4f}")
                print(f" Normalized Diff: {normalized_diff.item():.4f}")

        # Sort differences by normalized difference and get top 10
        top_diffs = sorted(all_diffs, key=lambda x: x[4], reverse=True)[:10]

        print("\nTop 10 parameter differences (sorted by normalized difference):")
        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
            print(f"{fx_name.upper()} - {param_name}:")
            print(f" Initial: {initial_value.item():.4f}")
            print(f" ITO: {ito_value.item():.4f}")
            print(f" Normalized Diff: {normalized_diff.item():.4f}")
            print()

    def print_predicted_params(self, predicted_params):
        """Print the predicted parameters, converting tensors to NumPy arrays."""
        if predicted_params is None:
            print("No predicted parameters available.")
            return

        print("Predicted Parameters:")
        for fx_name, fx_params in predicted_params.items():
            print(f"\n{fx_name.upper()}:")
            if isinstance(fx_params, dict):
                for param_name, param_value in fx_params.items():
                    if isinstance(param_value, torch.Tensor):
                        param_value = param_value.detach().cpu().numpy()
                    print(f" {param_name}: {param_value}")
            elif isinstance(fx_params, torch.Tensor):
                param_value = fx_params.detach().cpu().numpy()
                print(f" {param_value}")
            else:
                print(f" {fx_params}")

    def get_param_output_string(self, params):
        """Return the parameters as a newline-joined string with 4-decimal values."""
        if params is None:
            return "No parameters available"

        output = []
        for fx_name, fx_params in params.items():
            output.append(f"{fx_name.upper()}:")
            if isinstance(fx_params, dict):
                for param_name, param_value in fx_params.items():
                    if isinstance(param_value, torch.Tensor):
                        param_value = param_value.item()
                    output.append(f" {param_name}: {param_value:.4f}")
            elif isinstance(fx_params, torch.Tensor):
                output.append(f" {fx_params.item():.4f}")
            else:
                output.append(f" {fx_params:.4f}")

        return "\n".join(output)

    def get_top_n_diff_string(self, initial_params, ito_params, top_n=5):
        """Return the ``top_n`` largest normalized parameter differences as text.

        Dict-valued effects are normalized by the parameter range taken from
        ``self.mastering_converter.fx_processors``; any other effect value is
        reported as ``width`` with the absolute difference used as-is.

        Returns:
            A newline-joined report, or ``"Cannot compare parameters"`` when
            either argument is ``None``.
        """
        if initial_params is None or ito_params is None:
            return "Cannot compare parameters"

        all_diffs = []
        for fx_name, initial_fx in initial_params.items():
            if isinstance(initial_fx, dict):
                for param_name, initial_value in initial_fx.items():
                    ito_value = ito_params[fx_name][param_name]

                    param_range = self.mastering_converter.fx_processors[fx_name].param_ranges[param_name]
                    normalized_diff = abs((ito_value - initial_value) / (param_range[1] - param_range[0]))

                    all_diffs.append((fx_name, param_name, initial_value.item(), ito_value.item(), normalized_diff.item()))
            else:
                initial_value = initial_fx
                ito_value = ito_params[fx_name]
                normalized_diff = abs(ito_value - initial_value)
                all_diffs.append((fx_name, 'width', initial_value.item(), ito_value.item(), normalized_diff.item()))

        top_diffs = sorted(all_diffs, key=lambda x: x[4], reverse=True)[:top_n]

        # The header reports the requested count; it used to say "Top 10"
        # regardless of `top_n`.
        output = [f"Top {top_n} parameter differences (sorted by normalized difference):"]
        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
            output.append(f"{fx_name.upper()} - {param_name}:")
            output.append(f" Initial: {initial_value:.4f}")
            output.append(f" ITO: {ito_value:.4f}")
            output.append(f" Normalized Diff: {normalized_diff:.4f}")
            output.append("")

        return "\n".join(output)


def reload_weights(model, ckpt_path, device):
    """Load the ``"model"`` state dict from a checkpoint file into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        ckpt_path: Path of the checkpoint; it must contain a ``"model"`` entry.
        device: ``map_location`` passed to ``torch.load``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)

    from collections import OrderedDict
    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in checkpoint["model"].items():
        # remove `module.` only when it is actually there; slicing every key
        # unconditionally mangles checkpoints saved without the prefix, and
        # strict=False below would then hide the total mismatch.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict, strict=False)