import dataclasses
import gc
import inspect
import logging
import os
from dataclasses import asdict, dataclass
from typing import Callable, List, NamedTuple, Optional, Sequence, Union

import numpy as np
import optuna
import torch
from optuna.pruners import HyperbandPruner
from optuna.samplers import TPESampler
from optuna.visualization import plot_optimization_history, plot_param_importances
from torch.utils.tensorboard.writer import SummaryWriter

import wandb
from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
from rl_algo_impls.runner.running_utils import (
    ALGOS,
    base_parser,
    get_device,
    hparam_dict,
    load_hyperparams,
    make_policy,
    set_seeds,
)
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
    MicrortsRewardDecayCallback,
)
from rl_algo_impls.shared.callbacks.optimize_callback import (
    Evaluation,
    OptimizeCallback,
    evaluation,
)
from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
from rl_algo_impls.shared.stats import EpisodesStats
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper


@dataclass
class StudyArgs:
    """Study-level options: Optuna bookkeeping, evaluation sizing, and WandB.

    ``parse_args`` routes every parsed command-line value whose name matches a
    field of this class into it; all remaining values become training args.
    """

    # --- Optuna study identity / persistence ---
    load_study: bool
    study_name: Optional[str] = None
    storage_path: Optional[str] = None

    # --- Search budget ---
    n_trials: int = 100
    n_jobs: int = 1

    # --- Evaluation sizing ---
    n_evaluations: int = 4
    n_eval_envs: int = 8
    n_eval_episodes: int = 16

    # Passed through to ``study.optimize(timeout=...)``; ``None`` means no limit.
    timeout: Optional[Union[int, float]] = None

    # --- Weights & Biases ---
    wandb_project_name: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
    wandb_group: Optional[str] = None

    # --- Misc ---
    virtual_display: bool = False


class Args(NamedTuple):
    """Result of ``parse_args``: training arguments plus study arguments."""

    # Training run arguments, as produced by ``RunArgs.expand_from_dict``.
    train_args: Sequence[RunArgs]
    # Study-wide settings (Optuna, evaluation sizing, WandB).
    study_args: StudyArgs


def parse_args() -> Args:
    """Parse the command line into training arguments and study arguments.

    Study-specific flags are added on top of ``base_parser()``. Parsed values
    whose name is a ``StudyArgs`` constructor parameter go to the study; every
    other value is treated as a training argument. Exactly one algo and one
    env must be given. When the study name or storage path is missing, each
    one that is ``None`` is derived from the first run's ``Config``. The WandB
    group falls back to the study name.

    Returns:
        ``Args(train_args, study_args)``.
    """
    parser = base_parser()

    # (flag, add_argument keyword options), registered in this order.
    study_options = (
        (
            "--load-study",
            dict(
                action="store_true",
                help="Load a preexisting study, useful for parallelization",
            ),
        ),
        ("--study-name", dict(type=str, help="Optuna study name")),
        (
            "--storage-path",
            dict(type=str, help="Path of database for Optuna to persist to"),
        ),
        (
            "--wandb-project-name",
            dict(
                type=str,
                default="rl-algo-impls-tuning",
                help="WandB project name to upload tuning data to. If none, won't upload",
            ),
        ),
        (
            "--wandb-entity",
            dict(type=str, help="WandB team. None uses the default entity"),
        ),
        (
            "--wandb-tags",
            dict(type=str, nargs="*", help="WandB tags to add to run"),
        ),
        (
            "--wandb-group",
            dict(type=str, help="WandB group to group trials under"),
        ),
        (
            "--n-trials",
            dict(type=int, default=100, help="Maximum number of trials"),
        ),
        (
            "--n-jobs",
            dict(type=int, default=1, help="Number of jobs to run in parallel"),
        ),
        (
            "--n-evaluations",
            dict(
                type=int,
                default=4,
                help="Number of evaluations during the training",
            ),
        ),
        (
            "--n-eval-envs",
            dict(
                type=int,
                default=8,
                help="Number of envs in vectorized eval environment",
            ),
        ),
        (
            "--n-eval-episodes",
            dict(
                type=int,
                default=16,
                help="Number of episodes to complete for evaluation",
            ),
        ),
        ("--timeout", dict(type=int, help="Seconds to timeout optimization")),
        (
            "--virtual-display",
            dict(action="store_true", help="Use headless virtual display"),
        ),
    )
    for flag, options in study_options:
        parser.add_argument(flag, **options)

    # Handy defaults for local debugging (intentionally left disabled):
    # parser.set_defaults(
    #     algo=["a2c"],
    #     env=["CartPole-v1"],
    #     seed=[100, 200, 300],
    #     n_trials=5,
    #     virtual_display=True,
    # )

    parsed = vars(parser.parse_args())
    study_fields = inspect.signature(StudyArgs).parameters
    study_dict = {name: val for name, val in parsed.items() if name in study_fields}
    train_dict = {
        name: val for name, val in parsed.items() if name not in study_fields
    }

    study_args = StudyArgs(**study_dict)
    # Hyperparameter tuning across algos and envs not supported
    assert len(train_dict["algo"]) == 1
    assert len(train_dict["env"]) == 1
    train_args = RunArgs.expand_from_dict(train_dict)

    if not (study_args.study_name and study_args.storage_path):
        first_run = train_args[0]
        config = Config(
            first_run, load_hyperparams(first_run.algo, first_run.env), os.getcwd()
        )
        if study_args.study_name is None:
            study_args.study_name = config.run_name(include_seed=False)
        if study_args.storage_path is None:
            db_path = os.path.join(config.runs_dir, "tuning.db")
            study_args.storage_path = f"sqlite:///{db_path}"

    # Default set group name to study name
    if not study_args.wandb_group:
        study_args.wandb_group = study_args.study_name

    return Args(train_args, study_args)


def objective_fn(
    args: Sequence[RunArgs], study_args: StudyArgs
) -> Callable[[optuna.Trial], float]:
    """Build the Optuna objective for the given training runs.

    The returned callable dispatches to ``simple_optimize`` when exactly one
    set of run arguments was supplied and to ``stepwise_optimize`` otherwise.
    """

    def objective(trial: optuna.Trial) -> float:
        if len(args) != 1:
            return stepwise_optimize(trial, args, study_args)
        return simple_optimize(trial, args[0], study_args)

    return objective


def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -> float:
    """Run one Optuna trial as a single training run and return its score.

    Hyperparameters are sampled for ``trial`` on top of the base
    hyperparameters for ``args.algo``/``args.env``. A policy is then trained
    for ``config.n_timesteps`` steps with an ``OptimizeCallback`` attached
    (``step_freq = n_timesteps // n_evaluations``). If the callback has not
    flagged the trial as pruned after training, a final evaluation is run, and
    the policy is saved if the trial is still not pruned. Results are written
    to TensorBoard and, when a WandB project name is set, to WandB.

    Args:
        trial: Optuna trial used for sampling and handed to the callback.
        args: Arguments of the single training run.
        study_args: Study-wide settings (evaluation sizing, WandB options).

    Returns:
        ``optimize_callback.last_score``, or ``nan`` if an ``AssertionError``
        is raised while training or evaluating.

    Raises:
        ValueError: if ``args.algo`` has no parameter sampler (only "a2c").
        optuna.exceptions.TrialPruned: if the callback flagged the trial as
            pruned.
    """
    base_hyperparams = load_hyperparams(args.algo, args.env)
    base_config = Config(args, base_hyperparams, os.getcwd())
    if args.algo == "a2c":
        hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
    else:
        raise ValueError(f"Optimizing {args.algo} isn't supported")
    config = Config(args, hyperparams, os.getcwd())

    wandb_enabled = bool(study_args.wandb_project_name)
    if wandb_enabled:
        wandb.init(
            project=study_args.wandb_project_name,
            entity=study_args.wandb_entity,
            config=asdict(hyperparams),
            name=f"{config.model_name()}-{str(trial.number)}",
            tags=study_args.wandb_tags,
            group=study_args.wandb_group,
            sync_tensorboard=True,
            monitor_gym=True,
            save_code=True,
            reinit=True,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)
    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_env(
        config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
    )
    device = get_device(config, env)
    policy_factory = lambda: make_policy(
        args.algo, env, device, **config.policy_hyperparams
    )
    policy = policy_factory()
    algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)

    eval_env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_hparams={"n_envs": study_args.n_eval_envs},
    )
    optimize_callback = OptimizeCallback(
        policy,
        eval_env,
        trial,
        tb_writer,
        step_freq=config.n_timesteps // study_args.n_evaluations,
        n_episodes=study_args.n_eval_episodes,
        deterministic=config.eval_hyperparams.get("deterministic", True),
    )
    callbacks: List[Callback] = [optimize_callback]
    if config.hyperparams.microrts_reward_decay_callback:
        callbacks.append(MicrortsRewardDecayCallback(config, env))
    self_play_wrapper = find_wrapper(env, SelfPlayWrapper)
    if self_play_wrapper:
        callbacks.append(SelfPlayCallback(policy, policy_factory, self_play_wrapper))
    try:
        algo.learn(config.n_timesteps, callbacks=callbacks)

        if not optimize_callback.is_pruned:
            optimize_callback.evaluate()
            if not optimize_callback.is_pruned:
                policy.save(config.model_dir_path(best=False))

        # Read the stats from the OptimizeCallback created above; the previous
        # code referenced an undefined name here and raised NameError.
        eval_stat: EpisodesStats = optimize_callback.last_eval_stat  # type: ignore
        train_stat: EpisodesStats = optimize_callback.last_train_stat  # type: ignore

        tb_writer.add_hparams(
            hparam_dict(hyperparams, vars(args)),
            {
                "hparam/last_mean": eval_stat.score.mean,
                "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
                "hparam/train_mean": train_stat.score.mean,
                "hparam/train_result": train_stat.score.mean - train_stat.score.std,
                "hparam/score": optimize_callback.last_score,
                "hparam/is_pruned": optimize_callback.is_pruned,
            },
            None,
            config.run_name(),
        )
        # Close before finishing the WandB run, which syncs TensorBoard output.
        tb_writer.close()

        if wandb_enabled:
            wandb.run.summary["state"] = (  # type: ignore
                "Pruned" if optimize_callback.is_pruned else "Complete"
            )
            wandb.finish(quiet=True)

        if optimize_callback.is_pruned:
            raise optuna.exceptions.TrialPruned()

        return optimize_callback.last_score
    except AssertionError as e:
        logging.warning(e)
        # Mirror stepwise_optimize: mark the still-active run as errored.
        if wandb_enabled and wandb.run is not None:
            wandb_finish("Error")
        return np.nan
    finally:
        # Covers the error path; SummaryWriter.close ignores a repeated close.
        tb_writer.close()
        env.close()
        eval_env.close()
        gc.collect()
        torch.cuda.empty_cache()


def stepwise_optimize(
    trial: optuna.Trial, args: Sequence[RunArgs], study_args: StudyArgs
) -> float:
    """Run one Optuna trial across several training runs, in lock-step.

    One set of hyperparameters is sampled for the trial (using the first
    run's algo and env). Training is split into ``study_args.n_evaluations``
    stages. In each stage every run in ``args`` is trained for its share of
    ``config.n_timesteps``, evaluated, and saved to ``config.model_dir_path()``;
    from the second stage on, the policy (and the env's ``normalize_load_path``)
    is loaded from that path. The mean evaluation score over the runs is
    reported to the trial after each stage so the pruner can stop it early.

    Args:
        trial: Optuna trial used for sampling, reporting and pruning.
        args: Arguments of the training runs evaluated together.
        study_args: Study-wide settings (evaluation sizing, WandB options).

    Returns:
        The mean score of the final stage, ``-inf`` if ``n_evaluations`` is 0,
        or ``nan`` if an ``AssertionError`` is raised while training or
        evaluating any run.

    Raises:
        ValueError: if the algo has no parameter sampler (only "a2c").
        optuna.exceptions.TrialPruned: if the pruner stops the trial.
    """
    algo_name = args[0].algo
    env_id = args[0].env
    base_hyperparams = load_hyperparams(algo_name, env_id)
    base_config = Config(args[0], base_hyperparams, os.getcwd())
    if algo_name == "a2c":
        hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
    else:
        raise ValueError(f"Optimizing {algo_name} isn't supported")

    wandb_enabled = bool(study_args.wandb_project_name)
    if wandb_enabled:
        wandb.init(
            project=study_args.wandb_project_name,
            entity=study_args.wandb_entity,
            config=asdict(hyperparams),
            name=f"{str(trial.number)}-S{base_config.seed()}",
            tags=study_args.wandb_tags,
            group=study_args.wandb_group,
            save_code=True,
            reinit=True,
        )

    score = -np.inf

    for i in range(study_args.n_evaluations):
        evaluations: List[Evaluation] = []

        for arg in args:
            config = Config(arg, hyperparams, os.getcwd())

            tb_writer = SummaryWriter(config.tensorboard_summary_path)
            set_seeds(arg.seed, arg.use_deterministic_algorithms)

            # After the first stage, resume from what the previous stage saved.
            resume_path = config.model_dir_path() if i > 0 else None

            env = make_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                normalize_load_path=resume_path,
                tb_writer=tb_writer,
            )
            device = get_device(config, env)
            policy_factory = lambda: make_policy(
                arg.algo, env, device, **config.policy_hyperparams
            )
            policy = policy_factory()
            if i > 0:
                policy.load(config.model_dir_path())
            learner = ALGOS[arg.algo](
                policy, env, device, tb_writer, **config.algo_hyperparams
            )

            eval_env = make_eval_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                normalize_load_path=resume_path,
                override_hparams={"n_envs": study_args.n_eval_envs},
            )

            # Stage boundaries are computed from the totals so that integer
            # truncation does not accumulate across stages.
            start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
            train_timesteps = (
                int((i + 1) * config.n_timesteps / study_args.n_evaluations)
                - start_timesteps
            )

            callbacks = []
            if config.hyperparams.microrts_reward_decay_callback:
                callbacks.append(
                    MicrortsRewardDecayCallback(
                        config, env, start_timesteps=start_timesteps
                    )
                )
            self_play_wrapper = find_wrapper(env, SelfPlayWrapper)
            if self_play_wrapper:
                callbacks.append(
                    SelfPlayCallback(policy, policy_factory, self_play_wrapper)
                )
            try:
                learner.learn(
                    train_timesteps,
                    callbacks=callbacks,
                    total_timesteps=config.n_timesteps,
                    start_timesteps=start_timesteps,
                )

                evaluations.append(
                    evaluation(
                        policy,
                        eval_env,
                        tb_writer,
                        study_args.n_eval_episodes,
                        config.eval_hyperparams.get("deterministic", True),
                        start_timesteps + train_timesteps,
                    )
                )

                policy.save(config.model_dir_path())

                tb_writer.close()

            except AssertionError as e:
                logging.warning(e)
                if wandb_enabled:
                    wandb_finish("Error")
                return np.nan
            finally:
                # Covers the error path; SummaryWriter.close ignores a
                # repeated close.
                tb_writer.close()
                env.close()
                eval_env.close()
                gc.collect()
                torch.cuda.empty_cache()

        d = {}
        for idx, e in enumerate(evaluations):
            d[f"{idx}/eval_mean"] = e.eval_stat.score.mean
            d[f"{idx}/train_mean"] = e.train_stat.score.mean
            d[f"{idx}/score"] = e.score
        d["eval"] = np.mean([e.eval_stat.score.mean for e in evaluations]).item()
        d["train"] = np.mean([e.train_stat.score.mean for e in evaluations]).item()
        score = np.mean([e.score for e in evaluations]).item()
        d["score"] = score

        step = i + 1
        # wandb.init is only called when WandB is enabled, so logging must be
        # guarded the same way.
        if wandb_enabled:
            wandb.log(d, step=step)

        print(f"Trial #{trial.number} Step {step} Score: {round(score, 2)}")
        trial.report(score, step)
        if trial.should_prune():
            if wandb_enabled:
                wandb_finish("Pruned")
            raise optuna.exceptions.TrialPruned()

    if wandb_enabled:
        wandb_finish("Complete")
    return score


def wandb_finish(state: str) -> None:
    """Store ``state`` under the active WandB run's summary, then finish it."""
    active_run = wandb.run
    active_run.summary["state"] = state  # type: ignore
    wandb.finish(quiet=True)


def optimize() -> None:
    """Entry point: parse arguments, run the Optuna study, report results.

    When ``--virtual-display`` is given, a headless display is started for the
    duration of the study and stopped afterwards, even if the study fails.
    """
    train_args, study_args = parse_args()

    virtual_display = None
    if study_args.virtual_display:
        # Imported lazily so pyvirtualdisplay is only required when a virtual
        # display is actually requested.
        from pyvirtualdisplay.display import Display

        virtual_display = Display(visible=False, size=(1400, 900))
        virtual_display.start()

    try:
        _run_study(train_args, study_args)
    finally:
        if virtual_display is not None:
            virtual_display.stop()


def _run_study(train_args: Sequence[RunArgs], study_args: StudyArgs) -> None:
    """Create or load the study, optimize it, then print and plot the results."""
    sampler = TPESampler(**TPESampler.hyperopt_parameters())
    pruner = HyperbandPruner()
    if study_args.load_study:
        assert study_args.study_name
        assert study_args.storage_path
        study = optuna.load_study(
            study_name=study_args.study_name,
            storage=study_args.storage_path,
            sampler=sampler,
            pruner=pruner,
        )
    else:
        study = optuna.create_study(
            study_name=study_args.study_name,
            storage=study_args.storage_path,
            sampler=sampler,
            pruner=pruner,
            direction="maximize",
        )

    try:
        study.optimize(
            objective_fn(train_args, study_args),
            n_trials=study_args.n_trials,
            n_jobs=study_args.n_jobs,
            timeout=study_args.timeout,
        )
    except KeyboardInterrupt:
        # Allow Ctrl-C to stop the search and still report what finished.
        pass

    try:
        best = study.best_trial
    except ValueError:
        # Optuna raises ValueError when no trial has completed, e.g. after an
        # early interrupt or when every trial was pruned or failed.
        print("No trials completed; nothing to report")
        return

    print(f"Best Trial Value: {best.value}")
    print("Attributes:")
    for key, value in list(best.params.items()) + list(best.user_attrs.items()):
        print(f"  {key}: {value}")

    df = study.trials_dataframe()
    df = df[df.state == "COMPLETE"].sort_values(by=["value"], ascending=False)
    print(df.to_markdown(index=False))

    fig1 = plot_optimization_history(study)
    fig1.write_image("opt_history.png")

    fig2 = plot_param_importances(study)
    fig2.write_image("param_importances.png")


# Run the hyperparameter study only when this file is executed as a script.
if __name__ == "__main__":
    optimize()