from __future__ import annotations

import torch
from torch import amin  # Necessary for arcsin
import copy
import torch.nn as nn
import numpy as np

from scipy.optimize import curve_fit
from typing import Dict, Any, Tuple, List, Callable


def quantization(x, **params):
    """Forward transform: ``(1 / _s) * sinh(_0 * x)``.

    A ``_s`` entry equal to 0 is replaced by 10000 (via ``replace_num``)
    before its reciprocal is taken, so the division never hits zero.
    """
    inv_scale = torch.div(1, replace_num(params['_s'], num=0, to=10000))
    return inv_scale * torch.sinh(params['_0'] * x)


def dequantization(x, **params):
    """Inverse of ``quantization``: ``asinh(_s * x) / _0``.

    ``quantization`` maps ``v -> sinh(_0 * v) / _s``. Solving for ``v`` gives
    ``asinh(_s * y) / _0 = log(_s*y + sqrt(1 + (_s*y)**2)) / _0``. The other
    root, ``log(_s*y - sqrt(1 + (_s*y)**2))``, is invalid: its log argument is
    always negative, so a domain guard would clamp it to a constant and the
    output would no longer depend on ``x``. ``torch.asinh`` evaluates the valid
    root without the cancellation the ``log`` form suffers for large negative
    ``_s*y``.

    A ``_0`` equal to 0 is replaced by 10000 before its reciprocal is taken.
    NaNs in ``_s * x`` are mapped to 0 and infinities to the largest finite
    value (``domain_guard`` defaults), keeping the output finite.
    """
    inv_rate = torch.div(1, replace_num(params['_0'], num=0, to=10000))
    return inv_rate * torch.asinh(domain_guard(params['_s'] * x, nan=0.0))


def init_linear_scale(  # Symmetric scale. From the study folder
        x: torch.Tensor,
        **kwargs: Dict[str, Any],
    ) -> torch.Tensor:
    """Compute a symmetric scale from the transformed ``x``.

    ``x`` is transposed, passed through ``qtz_func`` with ``params`` plus
    ``_s`` set to ones, and transposed back. It is then reduced over dim 1
    (one value per row for a 2-D ``x``). Each scale is
    ``max(-min, max)`` of the transformed values, with 0 forced into the range,
    divided by half the width of the signed ``bits``-bit integer range. The
    result is clamped below by float32 eps.

    Keyword Args:
        bits: bit width of the signed integer target range.
        params: parameters forwarded to ``qtz_func``; must not contain ``_s``
            (it is passed explicitly).
        qtz_func: forward transform, called as ``qtz_func(x=..., **params, _s=...)``.

    Returns:
        float32 tensor of scales on the same device as the reduction result.

    NOTE(review): this module defines ``init_linear_scale`` twice with identical
    logic; the later definition rebinds the name at import time, so one copy
    can be deleted.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # init_ones yields one value per row of x; transposing puts rows on the
    # last axis so those values broadcast against it.
    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    # Force 0 into the range so the symmetric scale always covers it.
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    # Keep the scale strictly positive.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)


def init_params(x: torch.Tensor, **kwargs: Dict[str, Any]) -> Dict[str, nn.Parameter]:
    """Create the ``_0`` / ``_s`` parameters of the sinh transform for ``x``.

    ``_0`` comes from ``init_rand``. ``_s`` comes from ``init_linear_scale``,
    which sees the freshly created ``_0``. Both are wrapped as
    ``nn.Parameter`` with ``requires_grad=False``. If present in ``kwargs``,
    ``post_init_hook`` and then ``post_train_hook`` are called with
    ``parameters=`` the resulting dict.
    """
    raw = {}
    raw['_0'] = init_rand(
        x,
        qtz_func=quantization,
        deqtz_func=dequantization,
        param='_0',
        params_list=['_0', '_s'],
        **kwargs,
    )
    # init_linear_scale receives `raw` while it only holds `_0`; `_s` is added afterwards.
    raw['_s'] = init_linear_scale(x, params=raw, qtz_func=quantization, **kwargs)

    frozen = {name: nn.Parameter(value, requires_grad=False) for name, value in raw.items()}

    for hook_name in ('post_init_hook', 'post_train_hook'):
        if hook_name in kwargs:
            kwargs[hook_name](parameters=frozen)

    return frozen


############### Numpy Qtz ###############


def np_quantization(x, _0, _s):
    """NumPy forward transform: ``(1 / _s) * sinh(_0 * x)``, with ``_s == 0`` replaced by 10000."""
    inv_scale = np.divide(1, np_replace_num(_s, num=0, to=10000))
    return inv_scale * np.sinh(_0 * x)


def np_dequantization(x, _0, _s):
    """Inverse of ``np_quantization``: ``arcsinh(_s * x) / _0``.

    ``np_quantization`` maps ``v -> sinh(_0 * v) / _s``, whose inverse is
    ``log(_s*y + sqrt(1 + (_s*y)**2)) / _0``. The ``-sqrt`` root is invalid
    because its log argument is always negative. ``np.arcsinh`` computes the
    valid root without cancellation for large negative ``_s*y``.

    ``_0 == 0`` is replaced by 10000 before its reciprocal is taken. NaNs in
    ``_s * x`` become 0 and infinities become the largest finite value.
    """
    inv_rate = np.divide(1, np_replace_num(_0, num=0, to=10000))
    return inv_rate * np.arcsinh(np_domain_guard(_s * x, nan=0.0))


def fit_func(x, _0, _s):
    """Round-trip ``x`` through ``np_quantization`` and then ``np_dequantization``.

    The positional order ``(x, _0, _s)`` is the alphabetically sorted parameter
    order that ``init_non_linear_regression_fit`` passes to its fit function.
    """
    return np_dequantization(np_quantization(x, _0, _s), _0, _s)



############### HELPERS ###############

def domain_guard(
        x: torch.Tensor, 
        min: float = None, 
        max: float = None, 
        posinf: float = None,
        neginf: float = None,
        nan: float = None
    ) -> torch.Tensor:
    """Replace non-finite entries of ``x`` and optionally clamp the result.

    NaN, +inf and -inf are replaced through ``torch.nan_to_num`` using
    ``nan``, ``posinf`` and ``neginf``; ``None`` keeps torch's defaults.
    When at least one of ``min`` / ``max`` is given, the result is then
    clamped to that range.
    """
    cleaned = torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
    if min is None and max is None:
        return cleaned
    return torch.clamp(cleaned, min=min, max=max)


def replace_num(x: torch.Tensor, num: float, to: float) -> torch.Tensor:
    """Return ``x`` with every element equal to ``num`` swapped for ``to``.

    Args:
        x (torch.Tensor): Source tensor (not modified).
        num (float): Value to look for.
        to (float): Substitute value.

    Returns:
        torch.Tensor: New tensor from ``torch.where`` with the substitutions applied.
    """
    matches = x == num
    return torch.where(matches, to, x)


def guarded_torch_power(x: torch.Tensor, exp: float) -> torch.Tensor:
    """Compute ``x ** exp``; for ``exp`` below 1 negative bases are zeroed first."""
    if exp >= 1:
        return torch.pow(x, exp)
    # Fractional powers of negative bases are NaN in torch.pow; clip the bases at 0.
    non_negative = torch.relu(x)
    return torch.pow(non_negative, exp)


def init_ones(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 ones shaped like ``x`` reduced over dim 1; extra kwargs are ignored."""
    template = x.amin(dim=1)
    return torch.ones_like(template, dtype=torch.float32, device=x.device)


def init_rand(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 standard-normal samples shaped like ``x`` reduced over dim 1; kwargs are ignored."""
    template = x.amin(dim=1)
    return torch.randn_like(template, dtype=torch.float32, device=x.device)


def init_space_search(
        x: torch.Tensor,
        **kwargs: Dict[str, Any],
    ) -> torch.Tensor:
    """Initialise one transform parameter with an iterative random population search.

    The parameter named ``param`` is searched while every other name in
    ``params_list`` is fixed to ``init_ones(x)``. A candidate is scored by the
    MSE between ``x`` and ``deqtz_func(qtz_func(x))``, both applied on
    ``x.transpose(0, 1)``.

    The first run samples ``10 * n_random_params`` candidates as
    ``init_rand(x) * max_initial``. Every later run samples ``n_random_params``
    candidates around the mean of the previous successful run's
    ``n_best_to_pick`` best candidates. The best candidate over *all* runs is
    finally compared with ``init_ones(x)`` and ``init_rand(x)``, and the
    lowest-loss one is returned.

    Candidates whose evaluation raises, or whose loss is not finite, are
    skipped. If no candidate at all is usable, ``init_ones(x)`` is returned.

    Keyword Args:
        qtz_func, deqtz_func: forward / inverse transforms, called with ``x=``
            and the parameters as keyword arguments.
        params_list: names of all parameters the transforms take.
        param: name of the parameter being searched.

    Returns:
        Tensor for ``param`` (same shape as ``init_rand(x)``).
    """
    import math

    assert "qtz_func" in kwargs, "qtz_func must be provided."
    assert "deqtz_func" in kwargs, "deqtz_func must be provided."
    assert "params_list" in kwargs, "params list must be provided."
    assert "param" in kwargs, "param must be provided."

    qtz_func = kwargs.get('qtz_func')
    deqtz_func = kwargs.get('deqtz_func')
    params_list = kwargs.get('params_list')
    param = kwargs.get('param')

    n_runs = 50                # Number of search runs
    n_random_params = 50       # Candidates per run (10x this in the first run)
    n_best_to_pick = 5         # Best candidates kept to seed the next run
    max_initial = 10000        # Std-dev of the first run's normal candidates

    # Every parameter except the searched one stays fixed at ones.
    base_params = {p: init_ones(x, **kwargs) for p in params_list if p != param}

    def _calc(candidate: torch.Tensor) -> torch.Tensor:
        """Round-trip ``x`` through the transforms with ``param`` set to ``candidate``."""
        call_params = {**base_params, param: candidate}
        x_ = x.transpose(0, 1)
        x_ = qtz_func(x=x_, **call_params)
        x_ = deqtz_func(x=x_, **call_params)
        return x_.transpose(0, 1)

    def _loss(candidate: torch.Tensor):
        """Return the finite round-trip MSE of ``candidate``, or None if it is unusable."""
        try:
            value = nn.MSELoss()(x, _calc(candidate)).item()
        except Exception:  # The candidate might not be valid for the function's domain
            return None
        return value if math.isfinite(value) else None

    def _initial_candidates():
        """Yield 10x the usual number of wide random candidates for a fresh start."""
        for _ in range(n_random_params * 10):
            yield init_rand(x) * max_initial

    def _candidates_around(seeds):
        """Yield candidates sampled around the seeds' mean, spread by their per-channel abs-max."""
        stacked = torch.stack(seeds)
        min_vals, max_vals = torch.aminmax(stacked, dim=0)
        spread = torch.max(-min_vals, max_vals)
        mean = torch.mean(stacked, dim=0)
        for _ in range(n_random_params):
            yield torch.randn_like(min_vals) * spread + mean

    best_loss, best_candidate = math.inf, None
    seeds = []
    for _ in range(n_runs):
        candidates = _candidates_around(seeds) if seeds else _initial_candidates()

        scored = []
        for candidate in candidates:
            loss = _loss(candidate)
            if loss is not None:
                scored.append((loss, candidate))

        if not scored:
            # Nothing usable this run: resample around the previous seeds (or from scratch).
            continue

        scored.sort(key=lambda item: item[0])  # stable: equal losses keep generation order
        seeds = [candidate for _, candidate in scored[:n_best_to_pick]]

        # A later run's best can be worse than an earlier run's, so track the overall best.
        if scored[0][0] < best_loss:
            best_loss, best_candidate = scored[0]

    # Compare the search result against the plain ones / random initialisations.
    finalists = [] if best_candidate is None else [(best_loss, best_candidate)]
    p_ones = init_ones(x, **kwargs)
    p_rand = init_rand(x, **kwargs)
    for fallback in (p_ones, p_rand):
        loss = _loss(fallback)
        if loss is not None:
            finalists.append((loss, fallback))

    if not finalists:
        return p_ones
    # min() keeps the first of equal losses: search result, then ones, then rand.
    return min(finalists, key=lambda item: item[0])[1]


def init_linear_scale(  # Symmetric scale. From the study folder
        x: torch.Tensor,
        **kwargs: Dict[str, Any],
    ) -> torch.Tensor:
    """Compute a symmetric scale from the transformed ``x``.

    ``x`` is transposed, passed through ``qtz_func`` with ``params`` plus
    ``_s`` set to ones, and transposed back. It is then reduced over dim 1
    (one value per row for a 2-D ``x``). Each scale is
    ``max(-min, max)`` of the transformed values, with 0 forced into the range,
    divided by half the width of the signed ``bits``-bit integer range. The
    result is clamped below by float32 eps.

    Keyword Args:
        bits: bit width of the signed integer target range.
        params: parameters forwarded to ``qtz_func``; must not contain ``_s``
            (it is passed explicitly).
        qtz_func: forward transform, called as ``qtz_func(x=..., **params, _s=...)``.

    Returns:
        float32 tensor of scales on the same device as the reduction result.

    NOTE(review): this module defines ``init_linear_scale`` twice with identical
    logic; this later definition is the one bound at import time, so the other
    copy can be deleted.
    """
    assert "bits" in kwargs, "bits must be provided."
    assert "params" in kwargs, "params must be provided."
    assert "qtz_func" in kwargs, "qtz_func must be provided."

    bits = kwargs.get('bits')
    params = kwargs.get('params')
    qtz_func = kwargs.get('qtz_func')

    # init_ones yields one value per row of x; transposing puts rows on the
    # last axis so those values broadcast against it.
    x_ = x.transpose(0, 1)
    x_ = qtz_func(x=x_, **params, _s=init_ones(x, **kwargs))
    x_ = x_.transpose(0, 1)

    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    min_vals, max_vals = torch.aminmax(x_, dim=1)
    # Force 0 into the range so the symmetric scale always covers it.
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    abs_max_val_per_ch = torch.max(-min_vals, max_vals)
    scale = abs_max_val_per_ch / (float(quant_max - quant_min) / 2)

    # Keep the scale strictly positive.
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(dtype=torch.float32, device=min_vals.device)


def init_non_linear_regression_fit(
        x: torch.Tensor,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, torch.Tensor]:
    """Fit the transform parameters of each row of ``x`` with ``scipy.optimize.curve_fit``.

    For every row, ``np_fit_func(xdata, *params)`` is fitted to the (sorted)
    row itself, so the model should reproduce its input. The fit uses
    Levenberg-Marquardt (``method='lm'``, ``maxfev=1000``) and starts from
    that row's entries of ``p0``.

    Keyword Args:
        params_list: parameter names. They are sorted alphabetically, and that
            order must match the positional parameters of ``np_fit_func``.
        np_fit_func: NumPy model ``f(xdata, *params)``.
        p0: mapping name -> tensor with one initial value per row of ``x``.

    Returns:
        Dict mapping each name in ``params_list`` to a float32 tensor (one value
        per row) on ``x.device``. A row whose fit fails keeps its ``p0`` values.
        ``curve_fit`` fails with ``ValueError`` for non-finite data or bad
        options, and with ``RuntimeError`` when it does not converge within
        ``maxfev``. Failures are reported with ``print``.
    """
    assert "params_list" in kwargs, "params list must be provided."
    assert "np_fit_func" in kwargs, "np_fit_func must be provided."
    assert "p0" in kwargs, "p0 must be provided."
    np_fit_func = kwargs.get('np_fit_func')
    # Sorted so the order matches the numpy fit func's positional arguments.
    params_list = sorted(kwargs.get('params_list'))

    def _to_numpy(t: torch.Tensor) -> np.ndarray:
        # detach(): .numpy() refuses tensors that require grad (e.g. nn.Parameter).
        # float64: curve_fit works in float64 anyway, and numpy has no bfloat16.
        return t.detach().cpu().to(torch.float64).numpy()

    # Sorting each row makes the data easier to fit.
    sorted_xdata = np.sort(_to_numpy(x), axis=-1)
    p0 = {k: _to_numpy(v) for k, v in kwargs.get('p0').items()}

    fitted = []
    failures = []
    for i, row in enumerate(sorted_xdata):
        row_p0 = [p0[p][i] for p in params_list]
        try:
            popt, _ = curve_fit(
                np_fit_func,
                row,
                row,
                maxfev=1000,
                p0=row_p0,
                method='lm'
            )
        except (ValueError, RuntimeError) as e:
            # Fall back per channel so one bad row does not discard the others' fits.
            failures.append((i, e))
            popt = row_p0
        fitted.append(popt)

    if failures:
        print(f"Could not fit {len(failures)} of {len(fitted)} channels; first error: {failures[0][1]}")
        print("Using p0 as fallback for those channels...")

    # reshape keeps a (0, n_params) shape when x has no rows.
    fitted = np.asarray(fitted, dtype=np.float64).reshape(len(fitted), len(params_list))
    return {
        p: torch.tensor(fitted[:, j].copy(), dtype=torch.float32, device=x.device)
        for j, p in enumerate(params_list)
    }


def init_zeros(x: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor:
    """Return float32 zeros shaped like ``x`` reduced over dim 1; extra kwargs are ignored."""
    template = x.amin(dim=1)
    return torch.zeros_like(template, dtype=torch.float32, device=x.device)


def init_inner_scale(tensor: torch.Tensor, _min: float = torch.inf, _max: float = torch.inf) -> torch.Tensor:
    """Per-row factor mapping the zero-inclusive range of ``tensor`` onto ``[_min, _max]``.

    Each row is reduced over the last dim. Its range is widened to include 0,
    i.e. ``[min(row_min, 0), max(row_max, 0)]``. The factor is then
    ``(_max - _min) / (x_max - x_min)``. If ``_max`` is +inf (the default),
    no rescaling is wanted and ones are returned.

    Returns:
        Tensor shaped like ``tensor`` without its last dimension.
    """
    min_vals, max_vals = torch.aminmax(tensor, dim=-1)
    x_min = torch.min(min_vals, torch.zeros_like(min_vals))
    x_max = torch.max(max_vals, torch.zeros_like(max_vals))

    # Compare by value: an identity check (`is torch.inf`) misses float('inf') / np.inf.
    if _max == torch.inf:  # We do not need to scale the tensor. Just need to move it
        return torch.ones_like(x_min)

    # An all-zero row has an empty range; use a span of 1 there instead of
    # dividing by zero (scaling zeros by any finite factor leaves them zero).
    span = x_max - x_min
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (_max - _min) / span



############## Quant ###############

@torch.enable_grad()
def learn_parameters(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    qtz_func: nn.Module,
    deqtz_func: nn.Module,
    bits: int,
    target_dtype: torch.dtype,
    epochs: int = 1000,
    early_stop: bool = True,
    do_report: bool = False
) -> Tuple[Dict[str, nn.Parameter], torch.Tensor]:
    """Fit ``params`` by minimising the quantize/dequantize round-trip MSE with Adam.

    Runs up to ``epochs`` Adam steps. The learning rate starts at 1e-3 and
    decays linearly to 1e-6 over the first ``epochs // 10`` steps. The
    parameter values that produced the lowest loss are returned as new
    ``nn.Parameter`` objects with ``requires_grad=False``. On exit, the
    ``requires_grad`` flags of the passed-in ``params`` are restored and their
    ``.grad`` is cleared.

    Args:
        x: tensor to reproduce.
        params: parameters to optimise (modified in place during training).
        qtz_func, deqtz_func: transforms forwarded to ``quantize`` / ``dequantize``.
        bits: signed integer bit width used by ``quantize``.
        target_dtype: dtype ``quantize`` casts to.
        epochs: maximum number of optimisation steps.
        early_stop: stop once the loss changed by less than 1e-7 compared with
            ``int(epochs * 0.1)`` steps earlier (disabled when that window is 0).
        do_report: print the loss every 10 steps and also return the loss history.

    Returns:
        ``best_params``, or ``(best_params, losses)`` when ``do_report`` is set,
        where ``losses`` is a list of floats.

    Raises:
        Exception: if the loss becomes NaN or Inf.

    NOTE(review): ``quantize`` casts its output to ``target_dtype``. An integer
    dtype carries no autograd history, so only ``deqtz_func``'s direct use of
    the parameters receives gradients — confirm this is intended.
    """
    # Remember the caller's flags so they can be restored however we exit.
    original_flags = {k: p.requires_grad for k, p in params.items()}
    for p in params.values():
        p.requires_grad = True
        p.grad = None

    def _snapshot() -> Dict[str, nn.Parameter]:
        """Detached, frozen copy of the current parameter values."""
        return {k: nn.Parameter(v.detach().clone(), requires_grad=False) for k, v in params.items()}

    # Defines optimizer and loss function
    optimizer = torch.optim.Adam(list(params.values()), lr=0.001)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.001, total_iters=epochs // 10)
    loss_fn = nn.MSELoss()

    best_loss = float("inf")
    best_params = None

    # Early stopping: compare the loss with the one from `epochs_before_stop` steps ago.
    min_delta = 1e-7
    acc_loss = []
    percent_epochs_before_stop = 0.1
    epochs_before_stop = int(epochs * percent_epochs_before_stop)

    try:
        for i in range(epochs):
            optimizer.zero_grad()

            quant = quantize(x, params, qtz_func, bits, target_dtype)
            dequant = dequantize(quant, params, deqtz_func, bits, x.dtype)
            loss = loss_fn(x, dequant)

            if loss.isnan() or loss.isinf():
                raise Exception("Loss is NaN or Inf. Stopping the search.")

            loss_value = loss.item()
            acc_loss.append(loss_value)

            # Reports loss every 10 steps
            if i % 10 == 0 and do_report:
                print(f"Epoch {i}: Loss {loss_value}")

            # Snapshot BEFORE optimizer.step(): `loss` was computed from the
            # current values, so these are the parameters that achieved it.
            if loss_value < best_loss:
                best_loss = loss_value
                best_params = _snapshot()

            loss.backward()
            optimizer.step()
            scheduler.step()

            # With an empty window (epochs < 10) the check would compare a loss
            # with itself and stop at epoch 1, so it is skipped.
            if (early_stop and epochs_before_stop > 0 and i > epochs_before_stop
                    and abs(acc_loss[i - epochs_before_stop] - loss_value) < min_delta):
                break
    finally:
        for k, p in params.items():
            p.requires_grad = original_flags[k]
            p.grad = None

    if best_params is None:  # epochs == 0: nothing ran, return the inputs' values
        best_params = _snapshot()

    if do_report:
        return best_params, acc_loss
    return best_params


def quantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    target_dtype: torch.dtype = torch.int8
) -> torch.Tensor:
    """Transform ``x`` with ``func``, round, clamp to the signed ``bits`` range and cast.

    ``func`` runs on ``x.transpose(0, 1)`` (presumably so per-row parameters
    broadcast over the last axis — verify against callers). Its result is
    transposed back and rounded with a straight-through estimator. It is then
    clamped to ``get_min_max_from_bits_signed(bits)`` and cast to
    ``target_dtype``. Casting to an integer dtype drops autograd history.
    """
    quant_min, quant_max = get_min_max_from_bits_signed(bits)
    transformed = func(x=x.transpose(0, 1), **params).transpose(0, 1)
    rounded = round_func_BPDA(transformed)
    return torch.clamp(rounded, quant_min, quant_max).to(target_dtype)


def dequantize(
    x: torch.Tensor,
    params: Dict[str, nn.Parameter],
    func: nn.Module,
    bits: int,
    out_dtype: torch.dtype
) -> torch.Tensor:
    """Cast ``x`` to ``out_dtype`` and apply ``func`` on its transpose.

    The result is transposed back to ``x``'s layout. ``bits`` is accepted for
    symmetry with ``quantize`` but is not used.
    """
    as_float = x.to(dtype=out_dtype)
    return func(x=as_float.transpose(0, 1), **params).transpose(0, 1)


def round_func_BPDA(input):
    """Round in the forward pass while behaving as the identity in backward.

    The output starts as a clone of ``input``, so gradients flow through it
    unchanged. Its values are then replaced by the rounded ones through
    ``.data``, which autograd does not record.
    """
    result = input.clone()
    result.data = torch.round(input).data
    return result


def get_min_max_from_bits_signed(bit_width: int) -> Tuple[int, int]:
    """Return the inclusive ``(min, max)`` of a signed two's-complement ``bit_width``-bit integer."""
    half_range = 2 ** (bit_width - 1)
    return -half_range, half_range - 1



############## Numpy ###############

def np_domain_guard(
        x: np.ndarray,
        min: float = None,
        max: float = None,
        posinf: float = None,
        neginf: float = None,
        nan: float = None
    ) -> np.ndarray:
    """Replace non-finite values in ``x`` and optionally clip the result.

    NaN, +inf and -inf are replaced through ``np.nan_to_num`` using ``nan``,
    ``posinf`` and ``neginf``. ``None`` selects NumPy's defaults: 0.0 for NaN
    and the dtype's extreme finite values for infinities. When at least one of
    ``min`` / ``max`` is given, the result is clipped to that range.

    ``np.nan_to_num`` itself cannot take ``nan=None`` on float input (it tries
    to copy ``None`` into the float array and raises ``TypeError``), so
    ``None`` is resolved to 0.0 here.
    """
    x = np.nan_to_num(x, nan=0.0 if nan is None else nan, posinf=posinf, neginf=neginf)
    if min is not None or max is not None:
        x = np.clip(x, min, max)
    return x


def np_replace_num(x: np.ndarray, num: float, to: float) -> np.ndarray:
    """Return ``x`` with every element equal to ``num`` swapped for ``to``.

    Args:
        x (np.ndarray): Source array (not modified).
        num (float): Value to look for.
        to (float): Substitute value.

    Returns:
        np.ndarray: New array from ``np.where`` with the substitutions applied.
    """
    hits = x == num
    return np.where(hits, to, x)


def np_guarded_power(x: np.ndarray, exp: float) -> np.ndarray:
    """Compute ``x ** exp``; for ``exp`` below 1 negative bases are clipped to 0 first."""
    if exp >= 1:
        return np.power(x, exp)
    # Fractional powers of negative bases are NaN; clip the bases at 0.
    non_negative = np.maximum(x, 0)
    return np.power(non_negative, exp)