import itertools
import functools
import math
import multiprocessing
from pathlib import Path

import matplotlib
# Global Matplotlib styling applied to every figure this script produces:
# 24 pt text, rendered through LaTeX (``text.usetex`` needs a working LaTeX
# installation) using Libertine/Biolinum text fonts and newtxmath for math.
# NOTE(review): ``biolinum`` is loaded twice in the preamble; re-loading an
# already-loaded package without options should be a no-op in LaTeX.
matplotlib.rcParams.update({'font.size': 24})
matplotlib.rcParams.update({
  "text.usetex": True,
  "text.latex.preamble": r"\usepackage{biolinum} \usepackage{libertineRoman} \usepackage{libertineMono} \usepackage{biolinum} \usepackage[libertine]{newtxmath}",
  # Post-process PostScript output with xpdf (Matplotlib ``ps.usedistiller``).
  'ps.usedistiller': "xpdf",
})

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import tqdm
from scipy.stats import wasserstein_distance

import pose_estimation


def cub(x, a, b, c):
    """Evaluate the constant-free cubic ``a*x**3 + b*x**2 + c*x``.

    Works elementwise on scalars and NumPy arrays alike.  Because there is no
    constant term, ``cub(0, ...)`` is always 0, and ``cub(1, a, b, c)`` equals
    ``a + b + c``.

    Args:
        x: Point(s) at which to evaluate the polynomial.
        a: Cubic coefficient.
        b: Quadratic coefficient.
        c: Linear coefficient.

    Returns:
        The polynomial value(s), with the same shape as ``x``.
    """
    square = x * x
    return a * (square * x) + b * square + c * x


def subsample(a, p=0.0005, seed=0, replace=True):
    """Return a random selection of ``int(p * len(a))`` rows of ``a``.

    The draw uses a private ``np.random.RandomState(seed)``, so it is
    reproducible without re-seeding NumPy's global RNG (which would silently
    reset random state for the rest of the process).  The indices drawn are the
    same as those produced by ``np.random.seed(seed)`` followed by
    ``np.random.choice(len(a), size=...)``.

    Args:
        a: Array indexable with an integer index array (e.g. ``np.ndarray``).
        p: Fraction of rows to keep; the sample size is ``int(p * len(a))``.
        seed: Seed for the local random generator.
        replace: Sample with replacement (the default, so rows may repeat).
            Pass ``False`` to get distinct rows.

    Returns:
        A new array containing the selected rows.
    """
    n_rows = len(a)
    n_sample = int(p * n_rows)
    if n_sample == 0:
        # Nothing to draw; return an empty selection with a's trailing shape.
        return a[np.zeros(0, dtype=np.intp)]

    rng = np.random.RandomState(seed)
    inds = rng.choice(n_rows, size=n_sample, replace=replace)
    # Fancy indexing already returns a copy, so no explicit ``.copy()`` is needed.
    return a[inds]


def read_cos_opt(path, fname="cos_hist.npy"):
    """Load and stack every ``fname`` array found recursively under ``path``.

    Files are visited in sorted path order.  ``Path.rglob`` yields entries in
    filesystem order, which is not guaranteed to be stable; sorting makes the
    row order of the result reproducible across machines and runs.

    Args:
        path: Root directory to search.
        fname: File name (glob pattern) of the ``.npy`` files to load.

    Returns:
        ``np.array`` of the loaded arrays, stacked along a new leading axis
        (an empty array when nothing matches).
    """
    return np.array([np.load(p) for p in sorted(Path(path).rglob(fname))])


def plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, bins=10, xy=None,
              bone_name="Left Upper Leg"):
    """Plot per-bone angle histograms before and after the fitted cubic warp.

    For every bone ``(i, j)`` in ``pose_estimation.SKELETON`` whose first
    keypoint is named ``bone_name``, a 2x4 figure is saved to
    ``out_dir / f"hist_{i_name}_{j_name}.png"``.  Its panels are:

    * top-left: histogram of the sampled angles (``hist_smpl_fpath``);
    * bottom-left: histogram of the optimised angles (``cos_opt_dir``);
    * centre (only when ``xy`` is given): the curve ``xy`` and the identity line;
    * top-right: histogram of the warped optimised angles;
    * bottom-right: all three histograms overlaid.

    ``params`` are the coefficients found by ``optimize``, which fits ``cub`` on
    angles normalised to ``[0, 1]`` (1 == 90 degrees).  The warp is therefore
    applied in that normalised space and scaled back to radians, matching how
    ``main`` computes the plotted curve.

    Args:
        cos_opt_dir: Directory searched by ``read_cos_opt`` for optimised cosines.
        hist_smpl_fpath: ``.npy`` file of sampled cosines, indexed ``[:, bone]``.
        params: ``(a, b, c)`` coefficients for ``cub``.
        out_dir: Existing output directory (``str`` or ``Path``).
        bins: Number of histogram bins over ``[0, 90]`` degrees.
        xy: Optional ``(x_deg, y_deg)`` curve drawn in the centre panel.
        bone_name: Only bones whose first keypoint has this name are plotted;
            ``None`` plots every bone.
    """
    out_dir = Path(out_dir)
    half_pi = math.pi / 2
    max_range = 90  # degrees
    hist_range = (0, max_range)
    xticks = [0, 15, 30, 45, 60, 75, 90]

    # Angles in radians.  Cosines are clipped first so rounding noise outside
    # [-1, 1] does not become NaN (which the histograms would silently drop).
    angle_opt = np.arccos(np.clip(read_cos_opt(cos_opt_dir), -1, 1))
    # Apply the warp in the same normalised space that ``optimize`` fitted it in.
    angle_warped = cub(angle_opt / half_pi, *params) * half_pi

    cos_smpl = np.load(hist_smpl_fpath)
    print(cos_smpl.shape)
    angle_smpl = np.arccos(np.clip(cos_smpl, -1, 1))

    deg_opt = np.degrees(angle_opt)
    deg_warped = np.degrees(angle_warped)
    deg_smpl = np.degrees(angle_smpl)

    for idx, (i, j) in enumerate(pose_estimation.SKELETON):
        i_name = pose_estimation.KPS[i]
        j_name = pose_estimation.KPS[j]
        if bone_name is not None and i_name != bone_name:
            continue

        gs = gridspec.GridSpec(2, 4)
        fig = plt.figure(tight_layout=True, figsize=(16, 8), dpi=300)
        try:
            ax0 = fig.add_subplot(gs[0, 0])
            ax0.hist(deg_smpl[:, idx], bins=bins, range=hist_range, density=True)
            ax0.set_xticks(xticks)
            ax0.tick_params(labelbottom=False, labelleft=True)

            ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)
            ax1.hist(deg_opt[:, idx], bins=bins, range=hist_range, density=True)
            ax1.set_xticks(xticks)

            if xy is not None:
                ax2 = fig.add_subplot(gs[:, 1:3])
                ax2.plot(xy[0], xy[1], linewidth=8)
                # Identity line for reference.
                ax2.plot(xy[0], xy[0], linewidth=4, linestyle="dashed")
                ax2.set_xticks(xticks)
                ax2.set_yticks(xticks)

            ax3 = fig.add_subplot(gs[0, 3], sharey=ax0)
            ax3.hist(deg_warped[:, idx], bins=bins, range=hist_range, density=True)
            ax3.set_xticks(xticks)
            ax3.tick_params(labelbottom=False, labelleft=False)

            ax4 = fig.add_subplot(gs[1, 3], sharex=ax3, sharey=ax1)
            alpha = 0.5
            ax4.hist(deg_opt[:, idx], bins=bins, range=hist_range, density=True, label=r"$\mathcal{B}_i$", alpha=alpha)
            ax4.hist(deg_warped[:, idx], bins=bins, range=hist_range, density=True, label=r"$f(\mathcal{B}_i)$", alpha=alpha)
            ax4.hist(deg_smpl[:, idx], bins=bins, range=hist_range, density=True, label=r"$\mathcal{A}_i$", alpha=alpha)
            ax4.set_xticks(xticks)
            ax4.tick_params(labelbottom=True, labelleft=False)
            ax4.legend()

            fig.savefig(out_dir / f"hist_{i_name}_{j_name}.png")
        finally:
            # Close this specific figure (not just the "current" one) so its
            # memory is released even if drawing or saving fails.
            plt.close(fig)


def kldiv(p_hist, q_hist, bin_edges=None):
    """Return the 1-D Wasserstein (earth mover's) distance between two histograms.

    Despite its historical name, this is *not* a KL divergence.  Bin heights are
    used as *weights* located at the bin positions.  Passing the heights as the
    samples themselves compares the distribution of density values and ignores
    where the mass lies: ``[1, 0]`` and ``[0, 1]`` would be at distance 0.

    Args:
        p_hist: Non-negative bin heights (counts or densities) with positive
            total mass.
        q_hist: Bin heights over the same bins as ``p_hist``.
        bin_edges: Optional ``len(p_hist) + 1`` bin edges.  When given, bin
            centres are used as positions and the distance is in edge units.
            By default, bins are assumed uniform and positions are bin indices,
            so the distance is measured in bin widths.

    Returns:
        The Wasserstein-1 distance as a float.

    Raises:
        ValueError: if the two histograms differ in shape.
    """
    p_hist = np.asarray(p_hist, dtype=float)
    q_hist = np.asarray(q_hist, dtype=float)
    if p_hist.shape != q_hist.shape:
        raise ValueError(f"histograms differ in shape: {p_hist.shape} vs {q_hist.shape}")

    if bin_edges is None:
        positions = np.arange(len(p_hist), dtype=float)
    else:
        edges = np.asarray(bin_edges, dtype=float)
        positions = (edges[:-1] + edges[1:]) / 2

    return wasserstein_distance(positions, positions, u_weights=p_hist, v_weights=q_hist)


def calc_histogram(x, bins=10, range=(0, 1)):
    """Return the density-normalised bin heights of ``x``; edges are discarded.

    Thin wrapper over ``np.histogram(..., density=True)``.  With the defaults
    the result has 10 bins spanning ``[0, 1]``; values outside ``range`` are
    ignored.

    Note: the ``range`` parameter shadows the builtin inside this function; the
    name is kept so existing keyword callers keep working.
    """
    return np.histogram(x, bins=bins, range=range, density=True)[0]

def step(params, angles_opt, p_hist, bone_idx=None):
    """Score one candidate ``(a, b, c)`` of the cubic warp; lower is better.

    A candidate is rejected with ``math.inf`` in either of two cases:

    * ``a + b + c`` (the value of ``cub`` at 1) exceeds 1;
    * warping some considered bone's angles produces values outside ``[0, 1]``.

    Otherwise, the score is the sum over the considered bones of
    ``kldiv(p_hist[i], histogram of the warped angles of bone i)``.

    Args:
        params: ``(a, b, c)`` coefficients for ``cub``.
        angles_opt: Normalised angles, indexed ``angles_opt[:, bone]``.
        p_hist: Target histograms, one per bone index.
        bone_idx: If given, only this bone contributes; otherwise all bones do.

    Returns:
        ``(score, params)``.  ``params`` is echoed back so results from an
        unordered parallel map can be matched to their candidate.
    """
    if sum(params) > 1:
        return math.inf, params

    selected_bones = (
        i for i, _ in enumerate(pose_estimation.SKELETON)
        if bone_idx is None or i == bone_idx
    )

    score = 0
    for i in selected_bones:
        warped = cub(angles_opt[:, i], *params)
        # Kept as ``max > 1 or min < 0``: comparisons with NaN are False, so a
        # NaN-containing warp is not rejected here.
        if warped.max() > 1 or warped.min() < 0:
            return math.inf, params
        score += kldiv(p_hist[i], calc_histogram(warped))

    return score, params


def _load_normalized_angles(cos_opt_dir, hist_smpl_fpath):
    """Load both cosine sets and return ``(angles_opt, p_hist)`` in normalised units.

    Angles are ``arccos(cos) / (pi / 2)``, so 0 maps to 0 degrees and 1 to 90
    degrees.  ``p_hist`` holds one ``calc_histogram`` of the sampled angles per
    bone of ``pose_estimation.SKELETON``.

    Raises:
        ValueError: if any sampled cosine is negative or NaN (angle above 90
            degrees), since such angles fall outside the ``[0, 1]`` histogram range.
    """
    half_pi = math.pi / 2

    # Clip both inputs: tiny rounding excursions beyond [-1, 1] would otherwise
    # become NaN under arccos.
    cos_opt = np.clip(read_cos_opt(cos_opt_dir), -1, 1)
    angles_opt = np.arccos(cos_opt) / half_pi

    cos_smpl = np.load(hist_smpl_fpath)
    print(cos_smpl.shape)
    cos_smpl = np.clip(cos_smpl, -1, 1)
    # ``~(x >= 0)`` also flags NaN, which a plain ``x < 0`` test would miss.
    invalid = ~(cos_smpl >= 0)
    if invalid.any():
        raise ValueError(
            f"{invalid.mean():.4%} of sampled cosines are negative or NaN; "
            "expected angles within [0, 90] degrees"
        )
    angles_smpl = np.arccos(cos_smpl) / half_pi

    p_hist = [
        calc_histogram(angles_smpl[:, i])
        for i, _ in enumerate(pose_estimation.SKELETON)
    ]
    return angles_opt, p_hist


def optimize(cos_opt_dir, hist_smpl_fpath, bone_idx=None, processes=8,
             chunksize=1000, grid=None):
    """Brute-force grid search for the cubic warp coefficients ``(a, b, c)``.

    Each candidate is scored with ``step``, which compares the histogram of the
    warped optimised angles against the histogram of the sampled angles.  The
    ten best candidates are printed.

    Args:
        cos_opt_dir: Directory searched by ``read_cos_opt`` for optimised cosines.
        hist_smpl_fpath: ``.npy`` file of sampled cosines, indexed ``[:, bone]``.
        bone_idx: Restrict the objective to a single bone (``None`` = all bones).
        processes: Number of worker processes.
        chunksize: Candidates sent to a worker per task.
        grid: Optional sequence of three sequences ``(a_values, b_values,
            c_values)``; the Cartesian product is searched.  Each axis must
            support ``len``.  Defaults to 100 x 200 x 100 ``linspace`` values.

    Returns:
        The ``(a, b, c)`` tuple with the lowest score.  Ties go to the earliest
        candidate in grid order.

    Raises:
        ValueError: for invalid sampled cosines (see ``_load_normalized_angles``)
            or an empty grid.
    """
    angles_opt, p_hist = _load_normalized_angles(cos_opt_dir, hist_smpl_fpath)

    if grid is None:
        grid = (
            np.linspace(0, 20, 100),    # a: cubic coefficient
            np.linspace(-20, 20, 200),  # b: quadratic coefficient
            np.linspace(-20, 1, 100),   # c: linear coefficient
        )
    n_candidates = int(np.prod([len(axis) for axis in grid]))
    if n_candidates == 0:
        raise ValueError("the search grid is empty")

    score_fn = functools.partial(step, angles_opt=angles_opt, p_hist=p_hist, bone_idx=bone_idx)
    with multiprocessing.Pool(processes) as pool:
        # Ordered ``imap`` makes the argmin (and its tie-breaking) reproducible.
        # A large chunksize matters: the partial, including ``angles_opt`` and
        # ``p_hist``, is pickled once per task sent to a worker.
        results = list(
            tqdm.tqdm(
                pool.imap(score_fn, itertools.product(*grid), chunksize=chunksize),
                total=n_candidates,
            )
        )

    scores, params = zip(*results)
    ind = int(np.argmin(scores))
    best_params = params[ind]

    print(scores[ind], best_params)

    for i in np.argsort(scores)[:10]:
        print(scores[i])
        print(params[i])
        print()

    return best_params


def main():
    """Fit the cubic angle warp, then plot the histograms and the warp curve.

    Everything is written to ``hists/``: the per-bone histogram figures from
    ``plot_hist`` and ``new_out.png``, which shows the fitted curve against the
    identity line with both axes in degrees.
    """
    cos_opt_dir = "paper_single2_150mse"
    hist_smpl_fpath = "./data/hist_smpl.npy"  # alternative input: "./testtest.npy"

    # Hard-coded coefficient sets that can stand in for the ``optimize`` call:
    #   (1.2121212121212122, -1.105527638190953, 0.787878787878789)
    #   (0.20202020202020202, 0.30150753768844396, 0.3636363636363633)
    params = optimize(cos_opt_dir, hist_smpl_fpath)
    print(params)

    half_pi = math.pi / 2
    x_rad = np.linspace(0, half_pi, 100)
    # ``optimize`` fits the warp on angles normalised to [0, 1]: scale in and out.
    y_rad = cub(x_rad / half_pi, *params) * half_pi
    x_deg = x_rad * 180 / math.pi
    y_deg = y_rad * 180 / math.pi

    out_dir = Path("hists")
    out_dir.mkdir(parents=True, exist_ok=True)
    plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, xy=(x_deg, y_deg))

    ticks = [0, 15, 30, 45, 60, 75, 90]
    plt.figure(figsize=(4, 4), dpi=300)
    plt.plot(x_deg, y_deg, linewidth=6)
    plt.plot(x_deg, x_deg, linewidth=2, linestyle="dashed")
    plt.xticks(ticks)
    plt.yticks(ticks)
    plt.axis("equal")
    plt.tight_layout()
    plt.savefig(out_dir / "new_out.png")


if __name__ == "__main__":
    main()