#!/usr/bin/env python3
import argparse
from gym_minigrid.window import Window
from utils import *
import gym
import pickle
from datetime import datetime

# Every episode of this session (replaced by the contents of --load when given);
# each episode is a list of per-step dicts built by reset() / step().
episodes = list()
# One-element list used as a mutable flag, so functions can flip record[0]
# without needing a `global` statement.
record = [False]


def update_caption_with_recording_indicator():
    """Prefix the window caption with the current recording state.

    Reads the global ``record`` flag and the global ``window``'s caption text.
    It writes back ``"Recording ON"`` or ``"Recording OFF"``, then a dashed
    separator line, then the previous caption text.

    A header left by an earlier call is stripped first. Without that, repeated
    redraws would stack headers whenever nothing else resets the caption in
    between (presumably the case with ``--agent_view``, where ``redraw`` does
    not call ``env.render``).
    """
    separator = "\n------------------\n\n"
    caption_text = window.caption.get_text()

    # Remove a header written by a previous call so it is not duplicated.
    for previous_state in ("ON", "OFF"):
        previous_header = f"Recording {previous_state}{separator}"
        if caption_text.startswith(previous_header):
            caption_text = caption_text[len(previous_header):]
            break

    state = 'ON' if record[0] else 'OFF'
    window.set_caption(f"Recording {state}{separator}" + caption_text)

def redraw(img):
    """Show a frame in the window after refreshing the recording header.

    ``img`` is displayed as-is only when ``--agent_view`` is set. Otherwise it
    is ignored and a fresh ``rgb_array`` render of ``env`` is shown instead,
    honouring ``--tile_size`` and ``--mask-unobserved``.
    """
    frame = img if args.agent_view else env.render(
        'rgb_array',
        tile_size=args.tile_size,
        mask_unobserved=args.mask_unobserved,
    )

    # Refresh the "Recording ON/OFF" header so it reflects the current flag.
    update_caption_with_recording_indicator()

    window.show_img(frame)

def start_recording():
    """Switch recording on, starting from the step currently displayed.

    Setting the global ``record`` flag tags every subsequent step. The most
    recent step of the current episode is flagged too, so it is saved as well.
    On the first press within an episode, that step becomes the episode's
    first saved step.
    """
    record[0] = True
    print("Recording started")

    latest_step = episodes[-1][-1]
    latest_step["record"] = True

def reset():
    """Begin a new episode.

    Steps performed, in order:
    - open a new (empty) episode in ``episodes``;
    - reset ``env`` and turn recording off;
    - redraw the window;
    - log the initial observation as the episode's first step, with
      ``action``, ``reward`` and ``done`` set to None.
    """
    new_episode = []
    episodes.append(new_episode)

    obs, info = env.reset_with_info()
    record[0] = False  # every episode starts un-recorded until 'r' is pressed
    redraw(obs)

    initial_step = dict(
        action=None,
        info=info,
        obs=obs,
        reward=None,
        done=None,
        record=record[0],
    )
    new_episode.append(initial_step)


def _save_recorded_episodes():
    """Pickle every recorded step of every episode to ``<output_dir>/episodes.pkl``.

    Filtering and cleanup applied before saving:
    - steps whose ``"record"`` flag is False are dropped;
    - episodes with no recorded step are skipped;
    - the first kept step of each episode has ``action``, ``reward`` and
      ``done`` set to None, so it looks like a freshly reset state.

    The live ``episodes`` history is left untouched.
    """
    episodes_to_save = []
    for episode in episodes:
        recorded_steps = [s for s in episode if s["record"]]
        if not recorded_steps:
            continue
        # Copy before blanking the fields. Mutating the shared dict in place
        # would also erase the real action/reward/done from `episodes`.
        recorded_steps[0] = dict(recorded_steps[0], action=None, reward=None, done=None)
        episodes_to_save.append(recorded_steps)

    dump_pickle = Path(output_dir) / "episodes.pkl"
    print(f"Saving {len(episodes_to_save)} episodes ({[len(e) for e in episodes_to_save]}) to : {dump_pickle}")

    # Dump to a sibling temp file and swap it into place, so an interrupted
    # write cannot leave a truncated episodes.pkl behind.
    tmp_pickle = dump_pickle.with_name(dump_pickle.name + ".tmp")
    with open(tmp_pickle, 'wb') as f:
        pickle.dump(episodes_to_save, f)
    tmp_pickle.replace(dump_pickle)


def step(action):
    """Apply ``action`` to ``env``, log the transition, and re-save the recording.

    ``action`` is either:
    - a numpy array, passed to ``env.step`` unchanged; or
    - a discrete action (e.g. a member of ``env.actions``), sent as
      ``[int(action), nan, nan]``.

    The transition is appended to the current episode, tagged with the current
    recording flag. When the episode ends, a new one is started via
    ``reset()``. All recorded steps are then written to
    ``<output_dir>/episodes.pkl``.
    """
    if not isinstance(action, np.ndarray):
        action = [int(action), np.nan, np.nan]
    obs, reward, done, info = env.step(action)

    episodes[-1].append(
        {
            "action": action,
            "info": info,
            "obs": obs,
            "reward": reward,
            "done": done,
            "record": record[0],
        }
    )
    redraw(obs)

    if done:
        print('done!')
        print('Reward=%.2f' % (reward))

        # Start the next episode; reset() also appends its initial state to `episodes`.
        reset()
    else:
        print('\nStep=%s' % (env.step_count))

    # The file is rewritten after every step, presumably so progress is not
    # lost if the session ends abruptly.
    _save_recorded_episodes()


def key_handler(event):
    """Translate a key press into a control command.

    Bindings:
      r                   start recording from the current step
      escape              close the window
      s                   reset (start a new episode)
      tab / shift         step with an all-NaN action array
      left, right, up, t, space, 9, 0, enter
                          step with the matching ``env.actions`` member
      1-8, p              step with ``[nan, a, b]`` for the pair listed below
    Any other key is only echoed to stdout.
    """
    key = event.key
    print('pressed', key)

    # key -> name of the env.actions member to send
    primitive_actions = {
        'left': 'left',
        'right': 'right',
        'up': 'forward',
        't': 'speak',
        ' ': 'toggle',  # spacebar
        '9': 'pickup',
        '0': 'drop',
        'enter': 'done',
    }
    # key -> last two entries of an action array whose first entry is NaN
    array_actions = {
        '1': (0, 0),
        '2': (0, 1),
        '3': (1, 0),
        '4': (1, 1),
        '5': (2, 2),
        '6': (1, 2),
        '7': (2, 1),
        '8': (1, 3),
        'p': (3, 3),
    }

    if key == 'r':
        start_recording()
    elif key == 'escape':
        window.close()
    elif key == 's':
        reset()
    elif key in ('tab', 'shift'):
        step(np.array([np.nan, np.nan, np.nan]))
    elif key in primitive_actions:
        step(getattr(env.actions, primitive_actions[key]))
    elif key in array_actions:
        first, second = array_actions[key]
        step(np.array([np.nan, first, second]))


if __name__ == "__main__":
    # Script entry point: parse CLI options, build the env, and run the
    # interactive window. `args`, `env`, `output_dir` and `window` are bound
    # here as module globals because the handlers above read them.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--env",
        help="gym environment to load",
        # default="SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
        # default="SocialAI-ColorBoxesLLMCSParamEnv-v1",
        default="SocialAI-ColorLLMCSParamEnv-v1",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="random seed to generate the environment with",
        default=-1  # negative values skip env.seed() below
    )
    parser.add_argument(
        "--tile_size",
        type=int,
        help="size at which to render tiles",
        default=32
    )
    parser.add_argument(
        '--agent_view',
        default=False,
        help="draw what the agent sees (partially observable view)",
        action='store_true'
    )
    parser.add_argument(
        '--mask-unobserved',
        default=False,
        help="mask cells that are not observed by the agent",
        action='store_true'
    )
    parser.add_argument(
        '--save-dir',
        default="./llm_data/in_context_examples/",
        help="directory in which to save episodes",
    )
    parser.add_argument(
        '--load',
        default=None,
        help="Load in context examples to append to",
    )
    parser.add_argument(
        '--name',
        default="in_context",
        help="additional name tag for the episodes",
    )
    parser.add_argument(
        '--draw-tree',
        action="store_true",
        help="Draw the sampling tree",
    )

    # Put all env related arguments after --env-args, e.g. --env-args nb_foo 1 is_bar True
    parser.add_argument("--env-args", nargs='*', default=None)

    args = parser.parse_args()

    env = gym.make(args.env, **env_args_str_to_dict(args.env_args))

    # Each session writes into its own timestamped sub-directory of --save-dir.
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    output_dir = Path(args.save_dir) / f"{args.name}_{args.env}_{timestamp}"
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.load:
        # NOTE: unpickling can execute arbitrary code; only load files you
        # produced yourself. The loaded episodes replace the empty history and
        # are re-saved into the new output directory on the next step.
        with open(args.load, 'rb') as f:
            episodes = pickle.load(f)

    if args.draw_tree:
        # No leading "/" in the file name: pathlib treats "/x" as absolute and
        # would drop output_dir, writing the tree at the filesystem root.
        env.parameter_tree.draw_tree(
            filename=output_dir / f"{args.env}_raw_tree",
            ignore_labels=["Num_of_colors"],
        )

    if args.seed >= 0:
        env.seed(args.seed)

    window = Window('gym_minigrid - ' + args.env, figsize=(6, 4))
    window.reg_key_handler(key_handler)
    env.window = window

    # Start the first episode; this also draws its initial frame.
    reset()

    # Blocking event loop
    window.show(block=True)