import torch
import soundfile as sf
import numpy as np
import argparse
import os
import yaml
import julius

import sys
# Put the parent of this script's directory on the import path so the
# project-local `networks` and `modules` packages imported below resolve.
currentdir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.split(currentdir)[0])
from networks import Dasp_Mastering_Style_Transfer, Effects_Encoder
from modules.loss import AudioFeatureLoss, Loss, CLAPFeatureLoss
from modules.data_normalization import Audio_Effects_Normalizer


def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Convert ``wav`` laid out as (..., channels, time) to ``channels`` channels.

    - Same channel count: ``wav`` is returned unchanged.
    - ``channels == 1``: all input channels are averaged (downmix).
    - Mono input: the single channel is replicated (returned as an expanded view).
    - More input channels than requested: the first ``channels`` are kept.

    Raises:
        ValueError: if the input has fewer channels than requested and is not mono.
    """
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        return wav
    if channels == 1:
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        return wav.expand(*shape, channels, length)
    if src_channels > channels:
        return wav[..., :channels, :]
    raise ValueError(
        f"Cannot convert {src_channels}-channel audio to {channels} channels "
        "(input has fewer channels than requested but is not mono)."
    )


def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Convert audio to a new sample rate and number of audio channels.

    Args:
        wav: audio tensor laid out as (..., channels, time).
        from_rate: current sample rate (truncated to int for julius).
        to_rate: target sample rate (truncated to int for julius).
        to_channels: desired channel count; see ``convert_audio_channels``.

    Returns:
        The resampled, channel-converted tensor.
    """
    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
    return convert_audio_channels(wav, to_channels)

class MasteringStyleTransfer:
    """Reference-based mastering style transfer with optional inference-time optimization (ITO).

    Wraps an effects encoder (reference audio -> embedding) and a mastering
    converter (input audio + embedding -> processed audio and predicted
    effect parameters), plus helpers for audio preprocessing and
    human-readable parameter reports.
    """

    # Display metadata per effect / parameter: (friendly name, unit, min, max).
    # Kept as a class constant so it is not rebuilt on every report call.
    PARAM_MAPPER = {
        'eq': {
            'low_shelf_gain_db': ('Low Shelf Gain', 'dB', -20, 20),
            'low_shelf_cutoff_freq': ('Low Shelf Cutoff', 'Hz', 20, 2000),
            'low_shelf_q_factor': ('Low Shelf Q', '', 0.1, 5.0),
            'band0_gain_db': ('Low-Mid Band Gain', 'dB', -20, 20),
            'band0_cutoff_freq': ('Low-Mid Band Frequency', 'Hz', 80, 2000),
            'band0_q_factor': ('Low-Mid Band Q', '', 0.1, 5.0),
            'band1_gain_db': ('Mid Band Gain', 'dB', -20, 20),
            'band1_cutoff_freq': ('Mid Band Frequency', 'Hz', 2000, 8000),
            'band1_q_factor': ('Mid Band Q', '', 0.1, 5.0),
            'band2_gain_db': ('High-Mid Band Gain', 'dB', -20, 20),
            'band2_cutoff_freq': ('High-Mid Band Frequency', 'Hz', 8000, 12000),
            'band2_q_factor': ('High-Mid Band Q', '', 0.1, 5.0),
            'band3_gain_db': ('High Band Gain', 'dB', -20, 20),
            'band3_cutoff_freq': ('High Band Frequency', 'Hz', 12000, 20000),
            'band3_q_factor': ('High Band Q', '', 0.1, 5.0),
            'high_shelf_gain_db': ('High Shelf Gain', 'dB', -20, 20),
            'high_shelf_cutoff_freq': ('High Shelf Cutoff', 'Hz', 4000, 20000),
            'high_shelf_q_factor': ('High Shelf Q', '', 0.1, 5.0),
        },
        'distortion': {
            'drive_db': ('Drive', 'dB', 0, 8),
            'parallel_weight_factor': ('Dry/Wet Mix', '%', 0, 100),
        },
        'multiband_comp': {
            'low_cutoff': ('Low/Mid Crossover', 'Hz', 20, 1000),
            'high_cutoff': ('Mid/High Crossover', 'Hz', 1000, 20000),
            'parallel_weight_factor': ('Dry/Wet Mix', '%', 0, 100),
            'low_shelf_comp_thresh': ('Low Band Comp Threshold', 'dB', -60, 0),
            'low_shelf_comp_ratio': ('Low Band Comp Ratio', ': 1', 1, 20),
            'low_shelf_exp_thresh': ('Low Band Exp Threshold', 'dB', -60, 0),
            'low_shelf_exp_ratio': ('Low Band Exp Ratio', ': 1', 1, 20),
            'low_shelf_at': ('Low Band Attack Time', 'ms', 5, 100),
            'low_shelf_rt': ('Low Band Release Time', 'ms', 5, 100),
            'mid_band_comp_thresh': ('Mid Band Comp Threshold', 'dB', -60, 0),
            'mid_band_comp_ratio': ('Mid Band Comp Ratio', ': 1', 1, 20),
            'mid_band_exp_thresh': ('Mid Band Exp Threshold', 'dB', -60, 0),
            'mid_band_exp_ratio': ('Mid Band Exp Ratio', ': 1', 0, 1),
            'mid_band_at': ('Mid Band Attack Time', 'ms', 5, 100),
            'mid_band_rt': ('Mid Band Release Time', 'ms', 5, 100),
            'high_shelf_comp_thresh': ('High Band Comp Threshold', 'dB', -60, 0),
            'high_shelf_comp_ratio': ('High Band Comp Ratio', ': 1', 1, 20),
            'high_shelf_exp_thresh': ('High Band Exp Threshold', 'dB', -60, 0),
            'high_shelf_exp_ratio': ('High Band Exp Ratio', ': 1', 1, 20),
            'high_shelf_at': ('High Band Attack Time', 'ms', 5, 100),
            'high_shelf_rt': ('High Band Release Time', 'ms', 5, 100),
        },
        'gain': {
            'gain_db': ('Output Gain', 'dB', -24, 24),
        },
        'imager': {
            'width': ('Stereo Width', '', 0, 1),
        },
        'limiter': {
            'threshold': ('Threshold', 'dB', -60, 0),
            'at': ('Attack Time', 'ms', 5, 100),
            'rt': ('Release Time', 'ms', 5, 100),
        },
    }

    # Loss names accepted by inference_time_optimization (ito_config['loss_function']).
    _ITO_LOSSES = ('AudioFeatureLoss', 'CLAPFeatureLoss')

    def __init__(self, args):
        """Load both networks, the effects normalizer and the CLAP loss.

        Args:
            args: namespace providing ``cfg_enc``, ``encoder_path``,
                ``cfg_converter``, ``model_path``, ``sample_rate`` and
                ``fx_norm_feature_path``.
        """
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load models
        self.effects_encoder = self.load_effects_encoder()
        self.mastering_converter = self.load_mastering_converter()

        self.fx_normalizer = Audio_Effects_Normalizer(precomputed_feature_path=args.fx_norm_feature_path,
                                                      STEMS=['mixture'],
                                                      EFFECTS=['eq', 'imager', 'loudness'])
        # Loss functions
        self.clap_loss = CLAPFeatureLoss()

    def load_effects_encoder(self):
        """Build the effects encoder, load its checkpoint and put it in eval mode on ``self.device``."""
        effects_encoder = Effects_Encoder(self.args.cfg_enc)
        reload_weights(effects_encoder, self.args.encoder_path, self.device)
        effects_encoder.to(self.device)
        effects_encoder.eval()
        return effects_encoder

    def load_mastering_converter(self):
        """Build the TCN mastering converter, load its checkpoint and put it in eval mode on ``self.device``."""
        mastering_converter = Dasp_Mastering_Style_Transfer(num_features=2048,
                                                            sample_rate=self.args.sample_rate,
                                                            tgt_fx_names=['eq', 'distortion', 'multiband_comp', 'gain', 'imager', 'limiter'],
                                                            model_type='tcn',
                                                            config=self.args.cfg_converter,
                                                            batch_size=1)
        reload_weights(mastering_converter, self.args.model_path, self.device)
        mastering_converter.to(self.device)
        mastering_converter.eval()
        return mastering_converter

    def get_reference_embedding(self, reference_tensor):
        """Encode ``reference_tensor`` with the effects encoder (no gradient tracking)."""
        with torch.no_grad():
            reference_feature = self.effects_encoder(reference_tensor)
        return reference_feature

    def mastering_style_transfer(self, input_tensor, reference_feature):
        """Run the converter once without gradients.

        Returns:
            ``(output_audio, predicted_params)`` where ``predicted_params`` is
            whatever the converter's ``get_last_predicted_params()`` reports.
        """
        with torch.no_grad():
            output_audio = self.mastering_converter(input_tensor, reference_feature)
            predicted_params = self.mastering_converter.get_last_predicted_params()
        return output_audio, predicted_params

    def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
        """Optimize the reference embedding at inference time.

        Args:
            input_tensor: preprocessed input audio.
            reference_tensor: preprocessed reference audio (loss target).
            ito_config: dict with ``optimizer`` (name of a ``torch.optim`` class),
                ``learning_rate``, ``num_steps``, ``loss_function``
                (``'AudioFeatureLoss'`` or ``'CLAPFeatureLoss'``) and the
                keys that loss needs (``af_weights``/``sample_rate`` or
                ``clap_target_type``/``clap_text_prompt``/``clap_distance_fn``).
            initial_reference_feature: starting embedding. It is copied, so
                the caller's tensor is never modified.

        Returns:
            ``(all_results, min_loss_step)``. ``all_results`` holds one dict
            per step (1-based ``step``, ``loss``, ``audio`` as numpy,
            ``params``, ``log``). ``min_loss_step`` is the 0-based index of
            the lowest-loss step.

        Raises:
            ValueError: if ``ito_config['loss_function']`` is not supported.
        """
        loss_function = ito_config['loss_function']
        if loss_function not in self._ITO_LOSSES:
            raise ValueError(f"Unsupported ITO loss function: {loss_function}")

        # nn.Parameter aliases the storage of the tensor it wraps, so optimize
        # a private copy; otherwise optimizer.step() would overwrite the
        # caller's embedding in place.
        fit_embedding = torch.nn.Parameter(initial_reference_feature.detach().clone(), requires_grad=True)
        optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])

        # Only build the feature loss when it is actually used.
        af_loss = None
        if loss_function == 'AudioFeatureLoss':
            af_loss = AudioFeatureLoss(
                weights=ito_config['af_weights'],
                sample_rate=ito_config['sample_rate'],
                stem_separation=False,
                use_clap=False
            )

        min_loss = float('inf')
        min_loss_step = 0
        all_results = []
        initial_params = None

        for step in range(ito_config['num_steps']):
            optimizer.zero_grad()

            output_audio = self.mastering_converter(input_tensor, fit_embedding)
            current_params = self.mastering_converter.get_last_predicted_params()

            # Compute loss
            if af_loss is not None:
                losses = af_loss(output_audio, reference_tensor)
                total_loss = sum(losses.values())
            else:
                if ito_config['clap_target_type'] == 'Audio':
                    target = reference_tensor
                else:
                    target = ito_config['clap_text_prompt']
                total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])

            loss_value = total_loss.item()
            if loss_value < min_loss:
                min_loss = loss_value
                min_loss_step = step

            # Log top 5 parameter differences relative to the first step.
            if step == 0:
                initial_params = current_params
            top_5_diff = self.get_top_n_diff_string(initial_params, current_params, top_n=5)
            log_entry = f"Step {step + 1}\n   Loss: {loss_value:.4f}\n{top_5_diff}\n"

            all_results.append({
                'step': step + 1,
                'loss': loss_value,
                'audio': output_audio.detach().cpu().numpy(),
                'params': current_params,
                'log': log_entry
            })

            total_loss.backward()
            optimizer.step()

        return all_results, min_loss_step

    def preprocess_audio(self, audio, target_sample_rate=44100, normalize=False):
        """Convert a ``(sample_rate, ndarray)`` pair into a stereo float tensor on ``self.device``.

        Supported dtypes are int16, int32, float32 and float64. Integers are
        scaled into [-1, 1) and floats are clipped to [-1, 1]. All of them
        become float32.

        Supported layouts are 1-D mono, (2, N), (N, 2), (1, N) and (N, C).
        The (N, C) case uses the first column. Mono is duplicated to two
        channels.

        Audio is resampled to ``target_sample_rate`` if needed. When
        ``normalize`` is true, the 'mixture' effects normalization is applied.

        Returns:
            A float tensor with a leading batch dimension of 1.

        Raises:
            ValueError: for an unsupported dtype or an array with more than 2 dims.
        """
        sample_rate, data = audio

        # Normalize audio to float32 in the -1 to 1 range
        if data.dtype == np.int16:
            data = data.astype(np.float32) / 32768.0
        elif data.dtype == np.int32:
            data = (data / 2147483648.0).astype(np.float32)
        elif data.dtype in (np.float32, np.float64):
            data = np.clip(data, -1.0, 1.0).astype(np.float32, copy=False)
        else:
            raise ValueError(f"Unsupported audio data type: {data.dtype}")

        # Ensure stereo channels laid out as (2, samples)
        if data.ndim == 1:
            data = np.stack([data, data])
        elif data.ndim == 2:
            if data.shape[0] == 2:
                pass  # Already in correct shape
            elif data.shape[1] == 2:
                data = data.T
            elif data.shape[0] == 1:
                # Channel-first mono (1, samples): duplicate the single row.
                data = np.stack([data[0], data[0]])
            else:
                data = np.stack([data[:, 0], data[:, 0]])  # Duplicate first channel
        else:
            raise ValueError(f"Unsupported audio shape: {data.shape}")

        # Resample if necessary (julius expects integer rates)
        if sample_rate != target_sample_rate:
            data = julius.resample_frac(torch.from_numpy(data), int(sample_rate), int(target_sample_rate)).numpy()

        # Apply fx normalization for input audio during mastering style transfer;
        # the normalizer is called with a (samples, channels) layout.
        if normalize:
            data = self.fx_normalizer.normalize_audio(data.T, 'mixture').T

        # Convert to torch tensor with a batch dimension
        data_tensor = torch.FloatTensor(data).unsqueeze(0)

        return data_tensor.to(self.device)

    def process_audio(self, input_audio, reference_audio):
        """Preprocess both signals, embed the reference and run one style transfer.

        Returns:
            ``(output_audio, predicted_params, sample_rate, input_tensor)``,
            where ``input_tensor`` is the normalized, preprocessed input.
        """
        input_tensor = self.preprocess_audio(input_audio, self.args.sample_rate, normalize=True)
        reference_tensor = self.preprocess_audio(reference_audio, self.args.sample_rate)

        reference_feature = self.get_reference_embedding(reference_tensor)

        output_audio, predicted_params = self.mastering_style_transfer(input_tensor, reference_feature)

        return output_audio, predicted_params, self.args.sample_rate, input_tensor

    def get_param_output_string(self, params):
        """Render predicted effect parameters as a human-readable multi-line report.

        ``params`` maps effect names to either a dict of parameter values or,
        for the stereo imager, a single value that is shown as a 0-200% width.
        """
        if params is None:
            return "No parameters available"

        param_mapper = self.PARAM_MAPPER
        output = []
        for fx_name, fx_params in params.items():
            output.append(f"{fx_name.upper()}:")
            if not isinstance(fx_params, dict):
                # stereo imager: a single width value scaled to 0-200%
                width_percentage = fx_params.item() * 200
                output.append(f"  Stereo Width: {width_percentage:.2f}% (Range: 0-200%)")
                continue

            fx_mapper = param_mapper.get(fx_name, {})
            for param_name, param_value in fx_params.items():
                if isinstance(param_value, torch.Tensor):
                    param_value = param_value.item()

                if param_name not in fx_mapper:
                    output.append(f"  {param_name}: {param_value:.2f}")
                    continue

                friendly_name, unit, min_val, max_val = fx_mapper[param_name]
                if unit == '%':
                    param_value = param_value * 100
                current_content = f"  {friendly_name}: {param_value:.2f} {unit}"
                # Only this ratio has a 0-1 range, so its range is shown explicitly.
                if param_name == 'mid_band_exp_ratio':
                    current_content += f" (Range: {min_val}-{max_val})"
                output.append(current_content)

        return "\n".join(output)

    def get_top_n_diff_string(self, initial_params, ito_params, top_n=5):
        """Report the ``top_n`` parameters that moved most between two parameter sets.

        Dict-valued effects are normalized by the converter's
        ``fx_processors[fx].param_ranges`` span. Non-dict effects (the imager
        width) use the raw absolute difference.
        """
        if initial_params is None or ito_params is None:
            return "Cannot compare parameters"

        all_diffs = []
        for fx_name, initial_fx in initial_params.items():
            if isinstance(initial_fx, dict):
                param_ranges = self.mastering_converter.fx_processors[fx_name].param_ranges
                for param_name, initial_value in initial_fx.items():
                    ito_value = ito_params[fx_name][param_name]
                    low, high = param_ranges[param_name][0], param_ranges[param_name][1]
                    normalized_diff = abs((ito_value - initial_value) / (high - low))
                    all_diffs.append((fx_name, param_name, initial_value.item(), ito_value.item(), normalized_diff.item()))
            else:
                ito_value = ito_params[fx_name]
                normalized_diff = abs(ito_value - initial_fx)
                all_diffs.append((fx_name, 'width', initial_fx.item(), ito_value.item(), normalized_diff.item()))

        top_diffs = sorted(all_diffs, key=lambda x: x[4], reverse=True)[:top_n]

        output = [f"   Top {top_n} parameter differences (initial / ITO / normalized diff):"]
        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
            output.append(f"      {fx_name.upper()} - {param_name}: {initial_value:.2f} / {ito_value:.2f} / {normalized_diff:.2f}")

        return "\n".join(output)

def reload_weights(model, ckpt_path, device, prefix="module."):
    """Load checkpoint weights into ``model`` non-strictly.

    Accepts either a checkpoint dict with a ``"model"`` entry or a bare state
    dict. A leading ``prefix`` (the ``module.`` added by DataParallel /
    DistributedDataParallel wrappers) is stripped only from keys that
    actually carry it. The original code sliced 7 characters off every key,
    which mangled unwrapped checkpoints; ``strict=False`` then silently
    dropped every weight.

    Args:
        model: ``torch.nn.Module`` to load into.
        ckpt_path: path to a file readable by ``torch.load``.
        device: ``map_location`` for ``torch.load``.
        prefix: key prefix to strip when present.

    Returns:
        The result of ``model.load_state_dict`` (missing / unexpected keys).
    """
    from collections import OrderedDict

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

    new_state_dict = OrderedDict(
        (k[len(prefix):] if prefix and k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    return model.load_state_dict(new_state_dict, strict=False)