"""
Gradio interface for plotting attention.
"""

import chess
import gradio as gr
import torch
import uuid
import re

from . import constants, state, visualisation


def compute_cache(
    game_pgn,
    board_fen,
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Run the model on every position of a game and plot the current one.

    Positions are taken from ``game_pgn`` when it is non-blank: a
    whitespace-separated SAN move list, with move numbers ("1.", "1...",
    or glued forms such as "1.e4") ignored. Otherwise, when ``board_fen`` is
    non-blank, that single position is used. When both are blank, the
    standard starting position is used.

    Args:
        game_pgn: Move list typed by the user (overrides ``board_fen``).
        board_fen: FEN string typed by the user.
        attention_layer, attention_head, comp_index: Plot parameters, passed
            through to ``make_plot``.
        state_cache: Previous cache; ignored and replaced.
        state_board_index: Index of the board to display. This handler does
            not write it back, so it is only clamped for display here.

    Returns:
        The five ``make_plot`` outputs followed by the new cache: a list of
        ``(fen, state.model_cache(fen))`` pairs, one per position.

    Raises:
        gr.Error: If the FEN path is taken and ``board_fen`` is not a valid FEN.
    """
    game_pgn = (game_pgn or "").strip()
    board_fen = (board_fen or "").strip()
    if not game_pgn and board_fen:
        try:
            board = chess.Board(board_fen)
        except ValueError as err:
            # Surface a readable message instead of a generic handler crash.
            raise gr.Error(f"Invalid FEN: {err}") from err
        fen_list = [board.fen()]
    else:
        board = chess.Board()
        fen_list = [board.fen()]
        for token in game_pgn.split():
            # Drop a move-number prefix so "1.e4" / "12...Nf6" parse as moves.
            move = re.sub(r"^\d+\.+", "", token)
            if not move or move.endswith("."):
                continue
            # A game-termination marker ends the move list; it is not a move.
            if move in ("1-0", "0-1", "1/2-1/2", "*"):
                break
            try:
                board.push_san(move)
            except ValueError:
                gr.Warning(f"Invalid move {token}, stopping before it.")
                break
            fen_list.append(board.fen())
    state_cache = [(fen, state.model_cache(fen)) for fen in fen_list]
    # The board index may still point past the end of this (possibly shorter)
    # game, because it is not reset by this handler; clamp it to the last board.
    state_board_index = min(state_board_index, len(state_cache) - 1)
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_cache,
    )


def _split_attention(tokens, attention, tokenizer):
    """Split one attention row into board-configuration / meta / dump totals.

    Walks the input ``tokens`` alongside ``attention`` (one weight per
    token position).

    - The first 3 and the last 9 positions are counted as "dump" without
      being decoded.
    - The remaining tokens are decoded. Until a token starting with a space
      follows at least one board square, tokens are treated as the
      piece-placement field of the FEN: "/" separators go to "dump", every
      other token goes to "config". Each "config" token's weight is spread
      evenly over the squares it covers (digits expanded to that many empty
      squares).
    - Tokens after that point go to "meta".

    Args:
        tokens: Sequence of token ids, each decodable by ``tokenizer``.
        attention: Indexable attention weights; assumed 1-D tensor — TODO confirm.
        tokenizer: Object whose ``decode`` maps a token id to text.

    Returns:
        ``(heatmap, config_total, meta_total, dump_total)``. ``heatmap`` is a
        length-64 tensor filled in FEN reading order (rank 8 first).
    """
    config_total = meta_total = dump_total = 0
    config_done = False
    heatmap = torch.zeros(64)
    h_index = 0
    for i, token in enumerate(tokens):
        try:
            t_attn = attention[i]
            # Fixed leading/trailing positions are never decoded as board tokens
            # (presumably prompt scaffolding and the completion — TODO confirm).
            if (i < 3) or (i > len(tokens) - 10):
                dump_total += t_attn
                continue
            t_str = tokenizer.decode(token)
            # A space-prefixed token after at least one square ends the
            # piece-placement field; everything later is "meta".
            if t_str.startswith(" ") and h_index > 0:
                config_done = True
            if not config_done:
                if t_str == "/":
                    dump_total += t_attn
                    continue
                # "3" -> "000": one character per square.
                t_str = re.sub(r"\d", lambda m: "0" * int(m.group(0)), t_str)
                config_total += t_attn
                t_str_len = len(t_str.strip())
                per_square = t_attn / t_str_len
                for j in range(t_str_len):
                    heatmap[h_index + j] = per_square
                h_index += t_str_len
            else:
                meta_total += t_attn
        except IndexError:
            # Raised when the attention row is shorter than the token sequence,
            # or when more than 64 squares were decoded; stop scanning either way.
            break
    return heatmap, config_total, meta_total, dump_total


def make_plot(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Render one completion step's attention over one cached board.

    Args:
        attention_layer: 1-based layer index into the cached attentions.
        attention_head: 1-based head index.
        comp_index: 1-based completion step. Step 1 is the last prompt
            position; steps 2-5 are the later cached attention entries.
        state_cache: List of ``(fen, (out, cache))`` pairs built by
            ``compute_cache``, or ``None`` when not computed yet.
        state_board_index: Index into ``state_cache``; clamped to its bounds.

    Returns:
        ``(fen, info_text, colorbar_figure, svg_path, highlighted_tokens)``
        for the UI outputs, or five ``None`` when there is no cache.
        As a side effect, the board SVG is written under
        ``constants.FIGURE_DIRECTORY``.

    Raises:
        NotImplementedError: If the cache does not hold exactly 5 completion steps.
        gr.Error: If ``comp_index`` is outside ``1..5``.
    """
    if state_cache is None:
        gr.Warning("Cache not computed!")
        return None, None, None, None, None

    # Slider values are used as list indices below.
    attention_layer = int(attention_layer)
    attention_head = int(attention_head)
    comp_index = int(comp_index)
    # The stored index can be stale after loading a shorter game.
    state_board_index = max(0, min(int(state_board_index), len(state_cache) - 1))

    fen, (out, cache) = state_cache[state_board_index]
    tokenizer = state.model.tokenizer
    # out[0] holds the token ids that are decoded below; presumably prompt
    # followed by the generated completion — TODO confirm.
    tokens = out[0]

    # cache[layer] is a list of per-step attention tensors; [0, head] selects
    # batch 0 and the chosen head.
    attn_list = [a[0, attention_head - 1] for a in cache[attention_layer - 1]]
    prompt_attn, *comp_attn = attn_list
    # The last prompt row is the attention that produced the first completion token.
    comp_attn.insert(0, prompt_attn[-1:])
    comp_attn = [a.squeeze(0) for a in comp_attn]
    if len(comp_attn) != 5:
        raise NotImplementedError("This is not implemented yet.")
    if not 1 <= comp_index <= len(comp_attn):
        raise gr.Error(
            f"Completion index must be between 1 and {len(comp_attn)}, got {comp_index}."
        )
    raw_attention = comp_attn[comp_index - 1]

    heatmap, config_total, meta_total, dump_total = _split_attention(
        tokens, raw_attention, tokenizer
    )
    # HighlightedText expects plain numbers as confidences, not tensors.
    highlighted_tokens = [
        (tokenizer.decode(tokens[i]), float(score))
        for i, score in enumerate(raw_attention)
    ]

    completion = tokens[-5:]
    uci_move = tokenizer.decode(tokens[-5:-1]).strip()
    board = chess.Board(fen)
    # The heatmap is filled rank 8 first (FEN order). Flip the rows so that
    # index 0 is a1, python-chess's square numbering.
    heatmap = heatmap.view(8, 8).flip(0).view(64)
    try:
        move = chess.Move.from_uci(uci_move)
    except ValueError:
        # The model can emit text that is not a UCI move; still show the heatmap.
        gr.Warning(f"Predicted move '{uci_move}' is not valid UCI; no arrow drawn.")
        arrows = []
    else:
        arrows = [(move.from_square, move.to_square)]
    svg_board, fig = visualisation.render_heatmap(board, heatmap, arrows=arrows)
    info = (
        f"[Completion] Complete: '{tokenizer.decode(completion)}'"
        f" Chosen: '{tokenizer.decode(completion[comp_index-1])}'"
        f"\n[Distribution] Config: {config_total:.2f} Meta: {meta_total:.2f} Dump: {dump_total:.2f}"
    )
    svg_path = f"{constants.FIGURE_DIRECTORY}/board_{uuid.uuid4()}.svg"
    with open(svg_path, "w", encoding="utf-8") as svg_file:
        svg_file.write(svg_board)
    return board.fen(), info, fig, svg_path, highlighted_tokens


def previous_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Step back to the previous cached board and re-plot it.

    Stays on the first board (with a warning) when already there. If the
    stored index lies past the end of the cache (e.g. after loading a
    shorter game), it steps to the last board instead.

    Returns:
        The five ``make_plot`` outputs followed by the new board index.
    """
    if state_cache is None:
        # Nothing to step through; make_plot reports the missing cache.
        return (
            *make_plot(
                attention_layer, attention_head, comp_index, state_cache, state_board_index
            ),
            state_board_index,
        )
    # Clamp before stepping so a stale, too-large index cannot reach make_plot.
    state_board_index = min(state_board_index, len(state_cache)) - 1
    if state_board_index < 0:
        gr.Warning("Already at first board.")
        state_board_index = 0
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_board_index,
    )


def next_board(
    attention_layer,
    attention_head,
    comp_index,
    state_cache,
    state_board_index,
):
    """Advance to the next cached board and re-plot it.

    Stays on the last board (with a warning) when already there.

    Returns:
        The five ``make_plot`` outputs followed by the new board index.
    """
    if state_cache is None:
        # len(None) would raise; let make_plot report the missing cache instead.
        return (
            *make_plot(
                attention_layer, attention_head, comp_index, state_cache, state_board_index
            ),
            state_board_index,
        )
    state_board_index += 1
    if state_board_index >= len(state_cache):
        gr.Warning("Already at last board.")
        state_board_index = len(state_cache) - 1
    return (
        *make_plot(
            attention_layer, attention_head, comp_index, state_cache, state_board_index
        ),
        state_board_index,
    )


# Gradio UI: a game/FEN input, plotting controls, and the rendered board with
# its attention heatmap. Per-session state holds the model cache of every
# position and the index of the board being shown.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "Specify the game PGN or FEN string that you want to analyse (PGN overrides FEN)."
                )
                game_pgn = gr.Textbox(
                    label="Game PGN",
                    lines=1,
                )
                board_fen = gr.Textbox(
                    label="Board FEN",
                    lines=1,
                    max_lines=1,
                )
            compute_cache_button = gr.Button("Compute cache")
            with gr.Group():
                with gr.Row():
                    attention_layer = gr.Slider(
                        label="Attention layer",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    attention_head = gr.Slider(
                        label="Attention head",
                        minimum=1,
                        maximum=12,
                        step=1,
                        value=1,
                    )
                    comp_index = gr.Slider(
                        label="Completion index",
                        minimum=1,
                        # make_plot supports exactly 5 completion steps; 6 would
                        # index past the end of its attention list.
                        maximum=5,
                        step=1,
                        value=1,
                    )
                with gr.Row():
                    previous_board_button = gr.Button("Previous board")
                    next_board_button = gr.Button("Next board")
            # Output-only: shows the FEN of the board currently plotted.
            current_board_fen = gr.Textbox(
                label="Current board FEN",
                lines=1,
                max_lines=1,
            )
            info = gr.Textbox(
                label="Info",
                lines=1,
                info=(
                    "'Config' refers to the board configuration tokens."
                    "\n'Meta' to the additional board tokens (like color or castling)."
                    "\n'Dump' to the rest of the tokens (including '/')."
                ),
            )
            gr.Markdown(
                "Note that only the 'Config' attention is plotted.\n\nSee below for the raw attention."
            )
            raw_attention_html = gr.HighlightedText(
                label="Raw attention",
            )
        with gr.Column():
            image_board = gr.Image(label="Board")
            colorbar = gr.Plot(label="Colorbar")

    # Order must match the positional parameters / return tuple of make_plot.
    static_inputs = [
        attention_layer,
        attention_head,
        comp_index,
    ]
    static_outputs = [
        current_board_fen,
        info,
        colorbar,
        image_board,
        raw_attention_html,
    ]

    state_cache = gr.State(value=None)
    state_board_index = gr.State(value=0)
    compute_cache_button.click(
        compute_cache,
        inputs=[game_pgn, board_fen, *static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_cache],
    )

    previous_board_button.click(
        previous_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    next_board_button.click(
        next_board,
        inputs=[*static_inputs, state_cache, state_board_index],
        outputs=[*static_outputs, state_board_index],
    )
    # Re-plot the current board whenever any plotting parameter changes.
    for plot_parameter in static_inputs:
        plot_parameter.change(
            make_plot,
            inputs=[*static_inputs, state_cache, state_board_index],
            outputs=[*static_outputs],
        )