#!/usr/bin/env python -u
# -*- coding: utf-8 -*-

# Import future compatibility features for Python 2/3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import necessary libraries
import torch 
import torch.nn as nn
import numpy as np
from joblib import Parallel, delayed
from pesq import pesq  # PESQ metric for speech quality evaluation
import os 
import sys
import librosa  # Library for audio processing
import torchaudio  # Library for audio processing with PyTorch

# Module-level constants
MAX_WAV_VALUE = 32768.0  # 2**15; presumably the 16-bit PCM full-scale value for scaling samples - not referenced in the visible code, confirm at call sites
EPS = 1e-6  # Small positive constant; presumably added to denominators to avoid division by zero - not referenced in the visible code

def read_and_config_file(input_path, decode=0):
    """Collect the audio paths or input/label samples described by ``input_path``.

    In decode mode ``input_path`` may be a directory (searched for WAV files,
    falling back to FLAC when no WAV file is found) or a list file whose first
    whitespace-separated field on each line is an input path. Otherwise
    ``input_path`` must be a list file with ``input label [duration]`` lines.

    Blank lines are ignored. In pair mode, lines that do not have exactly two
    or three fields are skipped with a warning instead of silently re-adding
    the previous sample (or failing on an unbound variable).

    Args:
        input_path (str): Path to the input directory or list file.
        decode (int): Truthy to return a flat list of input paths, falsy to
            return a list of sample dictionaries.

    Returns:
        list: Input paths (decode mode), or dictionaries with the keys
        ``'inputs'``, ``'labels'`` and, when given, ``'duration'`` (float).
    """
    if decode:
        if os.path.isdir(input_path):
            processed_list = librosa.util.find_files(input_path, ext="wav")
            if len(processed_list) == 0:
                # No WAV files at all: fall back to FLAC.
                processed_list = librosa.util.find_files(input_path, ext="flac")
            return processed_list

        processed_list = []
        with open(input_path) as fid:
            for line in fid:
                fields = line.split()
                if fields:  # ignore blank lines instead of raising IndexError
                    processed_list.append(fields[0])
        return processed_list

    processed_list = []
    with open(input_path) as fid:
        for line_no, line in enumerate(fid, start=1):
            tmp_paths = line.split()
            if not tmp_paths:
                continue  # blank line
            if len(tmp_paths) == 3:  # input, label and duration
                sample = {'inputs': tmp_paths[0], 'labels': tmp_paths[1],
                          'duration': float(tmp_paths[2])}
            elif len(tmp_paths) == 2:  # input and label only
                sample = {'inputs': tmp_paths[0], 'labels': tmp_paths[1]}
            else:
                print('Warning: skipping line {} of {}: expected 2 or 3 fields, got {}'.format(
                    line_no, input_path, len(tmp_paths)))
                continue
            processed_list.append(sample)
    return processed_list

def load_checkpoint(checkpoint_path, use_cuda):
    """Deserialize a checkpoint file with ``torch.load``.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        use_cuda (bool): When truthy, tensors keep the device they were saved
            on (no ``map_location``); otherwise storages are kept where the
            unpickler first materializes them, i.e. on the CPU.

    Returns:
        dict: The object stored in the checkpoint file.
    """
    # ``map_location=None`` is torch.load's default, i.e. no remapping.
    remap = None if use_cuda else (lambda storage, loc: storage)
    return torch.load(checkpoint_path, map_location=remap)

def get_learning_rate(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer to inspect.

    Returns:
        float: The ``"lr"`` entry of ``optimizer.param_groups[0]``.
    """
    leading_group = optimizer.param_groups[0]
    return leading_group["lr"]

def reload_for_eval(model, checkpoint_dir, use_cuda):
    """Load trained weights into ``model`` from a checkpoint directory.

    The directory is expected to hold a pointer file, ``last_best_checkpoint``
    (preferred) or ``last_checkpoint``, whose first line names the checkpoint
    file to load. Weights are matched to the model key by key, tolerating a
    ``module.`` prefix that is present on only one side, and only when the
    tensor shapes agree. Parameters without a compatible checkpoint entry keep
    their current values and are reported in a warning.

    Args:
        model (nn.Module): The model to be reloaded (updated in place).
        checkpoint_dir (str): Directory containing the pointer file and checkpoint.
        use_cuda (bool): Unused; the checkpoint is always deserialized onto the
            CPU and copied into the model by ``load_state_dict``.

    Returns:
        None
    """
    print('Reloading from: {}'.format(checkpoint_dir))
    # Prefer the "best" pointer file, then fall back to the most recent one.
    for pointer in ('last_best_checkpoint', 'last_checkpoint'):
        name = os.path.join(checkpoint_dir, pointer)
        if os.path.isfile(name):
            break
    else:
        print('Warning: No existing checkpoint or best_model found!')
        return

    with open(name, 'r') as f:
        model_name = f.readline().strip()  # First line holds the checkpoint file name
    checkpoint_path = os.path.join(checkpoint_dir, model_name)
    print('Checkpoint path: {}'.format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    # Checkpoints either wrap the weights under 'model' or are a bare state dict.
    pretrained_model = checkpoint['model'] if 'model' in checkpoint else checkpoint

    state = model.state_dict()
    skipped = []
    for key in state:
        wanted_shape = state[key].shape
        # Try the exact key first, then without / with the 'module.' prefix.
        for candidate in (key, key.replace('module.', ''), 'module.' + key):
            if candidate in pretrained_model and pretrained_model[candidate].shape == wanted_shape:
                state[key] = pretrained_model[candidate]
                break
        else:
            skipped.append(key)
    model.load_state_dict(state)

    if skipped:
        # Make partial loads visible instead of silently keeping current values.
        print('Warning: {} parameter(s) had no shape-compatible entry in the checkpoint '
              'and were left unchanged: {}'.format(len(skipped), ', '.join(skipped)))
    print('=> Reloaded well-trained model {} for decoding.'.format(model_name))

def reload_model(model, optimizer, checkpoint_dir, use_cuda=True, strict=True):
    """Restore model and optimizer state from the checkpoint named in ``checkpoint_dir``.

    The directory's ``checkpoint`` pointer file names the checkpoint to load
    on its first line. When the pointer file does not exist, nothing is loaded
    and training starts from scratch.

    Args:
        model (nn.Module): The model to be reloaded.
        optimizer (torch.optim.Optimizer): The optimizer to be reloaded.
        checkpoint_dir (str): Directory containing checkpoints.
        use_cuda (bool): Forwarded to ``load_checkpoint``.
        strict (bool): Forwarded to ``model.load_state_dict``.

    Returns:
        tuple: ``(epoch, step)`` from the checkpoint, or ``(0, 0)`` when no
        pointer file exists.
    """
    pointer_file = os.path.join(checkpoint_dir, 'checkpoint')

    # Guard clause: nothing to resume from.
    if not os.path.isfile(pointer_file):
        print('[!] Checkpoint directory is empty. Train a new model ...')
        return 0, 0

    with open(pointer_file, 'r') as handle:
        model_name = handle.readline().strip()

    restored = load_checkpoint(os.path.join(checkpoint_dir, model_name), use_cuda)
    model.load_state_dict(restored['model'], strict=strict)
    optimizer.load_state_dict(restored['optimizer'])
    epoch, step = restored['epoch'], restored['step']
    print('=> Reloaded previous model and optimizer.')
    return epoch, step

def save_checkpoint(model, optimizer, epoch, step, checkpoint_dir, mode='checkpoint'):
    """Write a checkpoint file and record its name in a pointer file.

    The checkpoint is stored as ``model.ckpt-<epoch>-<step>.pt`` inside
    ``checkpoint_dir``; the pointer file ``<checkpoint_dir>/<mode>`` is
    overwritten with that file name.

    Args:
        model (nn.Module): The model whose state dict is saved.
        optimizer (torch.optim.Optimizer): The optimizer whose state dict is saved.
        epoch (int): Current epoch number.
        step (int): Current training step number.
        checkpoint_dir (str): Directory to save the checkpoint in.
        mode (str): Name of the pointer file to update.

    Returns:
        None
    """
    filename = 'model.ckpt-{}-{}.pt'.format(epoch, step)
    checkpoint_path = os.path.join(checkpoint_dir, filename)

    snapshot = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'step': step,
    }
    torch.save(snapshot, checkpoint_path)

    # The pointer file lets loaders find this checkpoint without listing the directory.
    pointer_path = os.path.join(checkpoint_dir, mode)
    with open(pointer_path, 'w') as pointer:
        pointer.write(filename)
    print("=> Saved checkpoint:", checkpoint_path)

def setup_lr(opt, lr):
    """Assign ``lr`` as the learning rate of every parameter group of ``opt``.

    Args:
        opt (torch.optim.Optimizer): The optimizer to update in place.
        lr (float): The learning rate to assign.

    Returns:
        None
    """
    for group in opt.param_groups:
        group.update(lr=lr)


def pesq_loss(clean, noisy, sr=16000):
    """Compute the wide-band PESQ score between a clean and a noisy signal.

    Args:
        clean (ndarray): The clean (reference) audio signal.
        noisy (ndarray): The noisy (degraded) audio signal.
        sr (int): Sample rate of the audio signals (default is 16000 Hz).

    Returns:
        float: The PESQ score, or -1 if the PESQ computation raised an error.
    """
    try:
        return pesq(sr, clean, noisy, 'wb')
    except Exception:
        # PESQ may fail, e.g. on silent audio; report that with the -1 sentinel.
        # Only Exception is caught so KeyboardInterrupt / SystemExit still
        # propagate instead of being turned into a score.
        return -1


def batch_pesq(clean, noisy, device=None):
    """Compute normalized PESQ scores for a batch of clean/noisy signal pairs.

    Args:
        clean (list of ndarray): List of clean audio signals.
        noisy (list of ndarray): List of noisy audio signals, paired by
            position with ``clean``.
        device (str or torch.device, optional): Device for the returned
            tensor. Defaults to ``'cuda'`` when CUDA is available and
            ``'cpu'`` otherwise, so the function also works on CPU-only hosts.

    Returns:
        torch.FloatTensor: Scores mapped through ``(score - 1) / 3.5`` (so a
        PESQ of 1.0 becomes 0 and 4.5 becomes 1), or None if any pair failed
        (``pesq_loss`` returned -1).
    """
    # Score every pair in parallel across all available cores.
    scores = Parallel(n_jobs=-1)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy))
    scores = np.array(scores)

    if (scores == -1).any():  # at least one PESQ computation failed
        return None

    scores = (scores - 1) / 3.5
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.FloatTensor(scores).to(device)


def power_compress(x):
    """Apply power-law (exponent 0.3) magnitude compression to a spectrogram.

    The phase is kept; only the magnitude is raised to the power 0.3.

    Args:
        x (torch.Tensor): Tensor whose last dimension holds the real part at
            index 0 and the imaginary part at index 1.

    Returns:
        torch.Tensor: Compressed real and imaginary parts stacked along
        dimension 1.
    """
    spectrum = torch.complex(x[..., 0], x[..., 1])
    angle = spectrum.angle()
    compressed_mag = spectrum.abs().pow(0.3)

    # Back to rectangular form with the original phase.
    parts = [compressed_mag * angle.cos(), compressed_mag * angle.sin()]
    return torch.stack(parts, dim=1)


def power_uncompress(real, imag):
    """Invert the exponent-0.3 magnitude compression of a spectrogram.

    The phase is kept; the magnitude is raised to the power ``1 / 0.3``.

    Args:
        real (torch.Tensor): Compressed real component.
        imag (torch.Tensor): Compressed imaginary component.

    Returns:
        torch.Tensor: Uncompressed real and imaginary parts stacked along a
        new last dimension.
    """
    spectrum = torch.complex(real, imag)
    angle = spectrum.angle()
    expanded_mag = spectrum.abs().pow(1. / 0.3)

    # Back to rectangular form with the original phase.
    parts = [expanded_mag * angle.cos(), expanded_mag * angle.sin()]
    return torch.stack(parts, dim=-1)


def stft(x, args, center=False):
    """Compute the Short-Time Fourier Transform (STFT) of an audio signal.

    Args:
        x (torch.Tensor): Input audio signal.
        args (Namespace): Configuration providing ``win_type`` (``'hamming'``
            or ``'hanning'``), ``win_len``, ``win_inc`` (hop size) and
            ``fft_len``.
        center (bool): Forwarded to ``torch.stft``.

    Returns:
        torch.Tensor: Real-valued STFT whose last dimension of size 2 holds
        the real and imaginary parts, or None (after printing a message) if
        ``win_type`` is not supported.
    """
    win_type = args.win_type
    win_len = args.win_len
    win_inc = args.win_inc
    fft_len = args.fft_len

    # Build the analysis window on the same device as the signal.
    if win_type == 'hamming':
        window = torch.hamming_window(win_len, periodic=False).to(x.device)
    elif win_type == 'hanning':
        window = torch.hann_window(win_len, periodic=False).to(x.device)
    else:
        print(f"In STFT, {win_type} is not supported!")
        return None

    # ``return_complex=False`` is deprecated in torch.stft; request the complex
    # result and view it as (..., 2) real, which is the same layout.
    spec = torch.stft(x, fft_len, win_inc, win_len, center=center, window=window,
                      return_complex=True)
    return torch.view_as_real(spec)


def istft(x, args, slen=None, center=False, normalized=False, onsided=None, return_complex=False):
    """Compute the inverse Short-Time Fourier Transform (ISTFT) of a spectrogram.

    Args:
        x (torch.Tensor): Spectrogram, either a complex tensor or a real
            tensor whose last dimension of size 2 holds the real and
            imaginary parts.
        args (Namespace): Configuration providing ``win_type`` (``'hamming'``
            or ``'hanning'``), ``win_len``, ``win_inc`` (hop size) and
            ``fft_len``.
        slen (int, optional): Length of the output signal.
        center (bool): Forwarded to ``torch.istft``.
        normalized (bool): Forwarded to ``torch.istft``.
        onsided (bool, optional): Forwarded to ``torch.istft`` as ``onesided``.
        return_complex (bool): Accepted for interface compatibility only; it
            is ignored and the output is always real-valued.

    Returns:
        torch.Tensor: The reconstructed audio signal, or None (after printing
        a message) if ``win_type`` is not supported.
    """
    win_type = args.win_type
    win_len = args.win_len
    win_inc = args.win_inc
    fft_len = args.fft_len

    # Build the synthesis window on the same device as the spectrogram.
    if win_type == 'hamming':
        window = torch.hamming_window(win_len, periodic=False).to(x.device)
    elif win_type == 'hanning':
        window = torch.hann_window(win_len, periodic=False).to(x.device)
    else:
        print(f"In ISTFT, {win_type} is not supported!")
        return None

    # torch.istft expects complex input. Convert the (..., 2) real layout
    # explicitly rather than relying on a failed first call to detect it, so a
    # genuine ISTFT error is reported as-is instead of being masked.
    if not x.is_complex():
        x = torch.view_as_complex(x.contiguous())

    return torch.istft(x, n_fft=fft_len, hop_length=win_inc, win_length=win_len,
                       window=window, center=center, normalized=normalized,
                       onesided=onsided, length=slen)


def compute_fbank(audio_in, args):
    """Extract Kaldi-compatible filter bank features from an audio signal.

    The window length and hop, given in samples by ``args``, are converted to
    milliseconds as required by ``torchaudio.compliance.kaldi.fbank``.
    Dithering is enabled (``dither=1.0``).

    Args:
        audio_in (torch.Tensor): Input audio signal.
        args (Namespace): Configuration providing ``win_len``, ``win_inc``,
            ``sampling_rate``, ``num_mels`` and ``win_type``.

    Returns:
        torch.Tensor: Computed filter bank features.
    """
    sample_rate = args.sampling_rate
    fbank_options = {
        'dither': 1.0,
        'frame_length': args.win_len / sample_rate * 1000,  # samples -> ms
        'frame_shift': args.win_inc / sample_rate * 1000,  # samples -> ms
        'num_mel_bins': args.num_mels,
        'sample_frequency': sample_rate,
        'window_type': args.win_type,
    }
    return torchaudio.compliance.kaldi.fbank(audio_in, **fbank_options)