import html
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gradio as gr
import spaces

# Load GPT-2 once at import time; the request handler below reuses these objects.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Use the end-of-sequence id as the padding id (GPT-2 presumably ships
# without a dedicated pad token).
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")

# Log whether CUDA is visible and, if so, which device is current.
_cuda_ok = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_ok}")
if _cuda_ok:
    _device_index = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(_device_index)}")

# CSS passed as ``css=`` to ``gr.Blocks`` below. It styles the beam tree built by
# generate_html: connector lines drawn with ::before/::after pseudo-elements,
# bordered node boxes, score tables, and the highlighted "chosen" tokens.
# Selectors are scoped to .tree / .prose so they do not restyle the rest of the
# Gradio page (links, list items and table headers outside the tree).
STYLE = """
.custom-container {
	display: grid;
	align-items: center;
    margin: 0!important;
    overflow: auto;
}
.prose ul ul {
    font-size: 10px!important;
}
.prose li {
    margin-bottom: 0!important;
}
.prose table {
    margin-bottom: 0!important;
}

.prose td, .prose th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
}

.tree {
	padding: 0px;
	margin: 0!important;
	box-sizing: border-box;
    font-size: 10px;
	width: 100%;
	height: auto;
	text-align: center;
    display:inline-block;
}

#root {
    display: inline-grid!important;
    width:auto!important;
    min-width: 220px;
}

.tree ul {
    padding-left: 20px;
    position: relative;
    transition: all 0.5s ease 0s;
    display: flex;
    flex-direction: column;
    gap: 10px;
    margin: 0px !important;
}
.tree li {
    display: flex;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-left: 20px;
    transition: all 0.5s ease 0s;
    flex-direction: row;
    justify-content: start;
    align-items: center;
}

.tree li::before, .tree li::after {
    content: "";
    position: absolute;
    left: 0px;
    border-left: 1px solid var(--body-text-color);
    width: 20px;
}
.tree li::before {
    top: 0;
    height:50%;
}
.tree li::after {
    top: 50%;
    height: 55%;
    bottom: auto;
    border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}

.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
	border-bottom: 1px solid var(--body-text-color);
	border-radius: 0px 0px 0px 5px;
	-webkit-border-radius: 0px 0px 0px 5px;
	-moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
	border-radius: 5px 0 0 0;
	-webkit-border-radius: 5px 0 0 0;
	-moz-border-radius: 5px 0 0 0;
}

.tree ul ul::before {
    content: "";
    position: absolute;
    left: 0;
    top: 50%;
    border-top: 1px solid var(--body-text-color);
    width: 20px;
    height: 0;
}
.tree ul:has(> li:only-child)::before {
    width:40px;
}

.tree li a:before {
    border-right: 1px solid var(--body-text-color);
    border-bottom: 1px solid var(--body-text-color);
    content: "";
    position: absolute;
    width: 10px;
    left: 0px;
    height: 10px;
    top: 50%;
    margin-top: -5px;
    margin-left: 6px;
    transform: rotate(315deg);
}


.tree li a {
	border: 1px solid var(--body-text-color);
	padding: 5px;
	border-radius: 5px;
	text-decoration-line: none;
	border-radius: 5px;
	transition: .5s;
    width: 260px;
    display: flex;
    align-items: center;
    justify-content: space-around;
}
.tree li a span {
	padding: 5px;
	font-size: 12px;
	text-transform: uppercase;
	letter-spacing: 1px;
	font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
	background: #ffedd5;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
	border-color: #7c2d12;
}
.chosen {
    background-color: #ea580c;
    width:auto!important;
}
"""


def generate_nodes(token, node, step=0):
    """Recursively render one tree node and its whole subtree as an HTML ``<li>``.

    Args:
        token: Label shown in the node box (the decoded token, or the prompt
            text for the root). It is HTML-escaped before insertion, because the
            root label comes straight from user input.
        node: Object exposing ``table`` (an HTML string or ``None``) and
            ``children`` (a mapping of child label -> child node).
        step: Depth of ``node``; ``0`` marks the root, which gets ``id='root'``.

    Returns:
        The HTML fragment for ``node`` and all of its descendants.
    """
    # A node without a score table is highlighted with the "chosen" class.
    css_class = "chosen" if node.table is None else ""
    # Only the root carries an id; other nodes omit the attribute entirely
    # instead of emitting an invalid empty id.
    id_attr = " id='root'" if step == 0 else ""
    label = html.escape(str(token))

    parts = [
        f" <li> <a href='#' class='{css_class}'{id_attr}> <span> <b>{label}</b> </span> ",
        node.table if node.table is not None else "",
        "</a>",
    ]
    if node.children:
        parts.append("<ul> ")
        parts.extend(
            generate_nodes(child_token, child_node, step=step + 1)
            for child_token, child_node in node.children.items()
        )
        parts.append("</ul>")
    parts.append("</li>")
    return "".join(parts)


def generate_markdown_table(scores, sequence_prob, top_k=4, chosen_tokens=None):
    """Render the ``top_k`` highest-scoring next tokens as an HTML table.

    Args:
        scores: 1-D scores indexed by token id (presumably one beam's row of
            step scores over the vocabulary — TODO confirm).
        sequence_prob: Running score of the beam; added to each step score to
            fill the "Total score" column.
        top_k: Number of rows to show. ``0`` yields a header-only table.
        chosen_tokens: Optional collection of decoded token strings whose rows
            get the ``chosen`` class.

    Returns:
        The table markup as a string.
    """
    parts = [
        """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Total score</b></th>
        </tr>"""
    ]
    # Token ids ordered by descending score. Slicing the reversed order with
    # [:top_k] (rather than [-top_k:]) keeps top_k=0 from selecting every id.
    ranked_ids = np.asarray(np.argsort(scores))[::-1][:top_k]
    for token_idx in ranked_ids:
        token = tokenizer.decode([token_idx])
        item_class = "chosen" if chosen_tokens and token in chosen_tokens else ""
        # Quote the class attribute so an empty class stays valid markup, and
        # escape the token text so characters such as "<" or "&" display literally.
        parts.append(
            f"""
        <tr class="{item_class}">
            <td>{html.escape(token)}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{scores[token_idx] + sequence_prob:.4f}</td>
        </tr>"""
        )
    parts.append(
        """
    </table>"""
    )
    return "".join(parts)


def generate_html(start_sentence, original_tree):
    """Wrap the rendered beam tree in the container markup targeted by STYLE.

    Args:
        start_sentence: Prompt text, used as the root node's label.
        original_tree: Root node of the beam tree (as built by generate_beams).

    Returns:
        An HTML string: a ``custom-container`` div holding a ``tree`` div whose
        top-level ``<ul>`` contains the whole node hierarchy.
    """
    html_output = """<div class="custom-container">
				<div class="tree">
                <ul>"""
    html_output += generate_nodes(start_sentence, original_tree)

    # Close the <ul>, the .tree div and the .custom-container div, in that order.
    # (The markup previously ended with a stray </body> that left the outer div open.)
    html_output += """
        </ul>
        </div>
    </div>
    """
    return html_output


import pandas as pd
from typing import Dict
from dataclasses import dataclass


@dataclass
class BeamNode:
    """One node of the beam-search tree rebuilt by ``generate_beams``.

    The root holds the prompt. Each child is reached by appending one decoded
    token to its parent's sentence.
    """

    # Parent's cumulative score plus this node's step score (0 for the root).
    cumulative_score: float
    # HTML score table for the candidates considered from this node; None until
    # the node is expanded (nodes that are never expanded keep None).
    table: Optional[str]
    # Prompt text followed by every decoded token on the path to this node.
    current_sentence: str
    # Child nodes keyed by the decoded token text that leads to them.
    children: Dict[str, "BeamNode"]


def generate_beams(start_sentence, scores, sequences, beam_indices):
    """Replay a beam search step by step and rebuild it as a tree of BeamNode.

    At each step the top ``n_beams`` tokens of every live beam are pooled into
    one candidate table. Candidates with the same (token id, current text) are
    collapsed to their best-scoring row, and the ``n_beams`` best candidates
    overall become the next step's beams. Every expanded beam node receives an
    HTML score table. Each selected continuation becomes a child node keyed by
    its decoded token text.

    Args:
        start_sentence: Prompt text, used as the root node's sentence.
        scores: Sequence of per-step score tensors, each indexed as
            ``[beam_ix][token_id]``. The length of the first step is taken as the
            number of beams. Presumably ``outputs.scores`` from
            ``model.generate(..., output_scores=True)`` — TODO confirm shape.
        sequences: Generated token ids (a tensor); only decoded and printed here.
        beam_indices: Not used by this function.

    Returns:
        The root BeamNode; its subtree contains every expanded beam.
    """
    # Debug output of the final decoded sequences.
    print(tokenizer.batch_decode(sequences))
    # NOTE(review): this converted array is never read below.
    sequences = sequences.cpu().numpy()
    original_tree = BeamNode(
        cumulative_score=0, table=None, current_sentence=start_sentence, children={}
    )
    n_beams = len(scores[0])
    # Every beam starts as a reference to the same root node object.
    beam_trees = [original_tree] * n_beams
    for step, step_scores in enumerate(scores):
        # Flat candidate lists: one entry per (beam, candidate token) pair.
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_completions,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):
            current_beam = beam_trees[beam_ix]
            # Top n_beams token ids for this beam, highest step score first
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            # Candidate cumulative score = step score + the beam's running score.
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_completions += [beam_trees[beam_ix].current_sentence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_completions": current_completions,
                "token": top_tokens,
            }
        )
        # Collapse duplicate candidates (same token on the same text, e.g. from
        # beams that alias one node): keep the row with the highest cumulative
        # score. On ties idxmax returns the first row, i.e. the lowest
        # beam_index, since rows were appended in beam order.
        maxes = top_df.groupby(["token_index", "current_completions"])[
            "cumulative_score"
        ].idxmax()

        top_df = top_df.loc[maxes]

        # Keep the n_beams best candidates overall, best first
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Attach a score table to each current beam node, highlighting the tokens selected from it.
        # Edge case: several beam_trees entries can be the same node object (all of them at step 0). Only the lowest such index keeps selected rows after the dedup above, so the others would write a table with nothing chosen; iterating in reverse lets the lowest index write last.
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            selected_tokens = top_df_selected.loc[
                top_df_selected["beam_index"] == beam_ix
            ]
            markdown_table = generate_markdown_table(
                step_scores[beam_ix, :],
                current_beam.cumulative_score,
                chosen_tokens=list(selected_tokens["token"].values),
            )
            beam_trees[beam_ix].table = markdown_table

        # Add new children for each beam
        # Running scores of the current beams, indexed by beam position.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])

            # Update the source tree
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])

            # NOTE(review): this sanity check stringifies the whole tree (tables
            # included) on every insertion and relies on the dataclass repr
            # growing. Children are keyed by decoded text, so two distinct token
            # ids that decode to the same string would overwrite a child here and
            # could trip the assert. It is also stripped under `python -O`.
            previous_len = len(str(original_tree))
            # NOTE(review): .numpy() assumes the score tensors are on CPU.
            beam_trees[source_beam_ix].children[current_token_choice] = BeamNode(
                table=None,
                children={},
                current_sentence=beam_trees[source_beam_ix].current_sentence
                + current_token_choice,
                cumulative_score=cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy(),
            )
            assert (
                len(str(original_tree)) > previous_len
            ), "Original tree has not increased size"

        # Reassign all beams at once
        # Next step's beam k continues from the source beam of the k-th selected
        # candidate (done in one comprehension so earlier reassignments in this
        # step do not affect later lookups).
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]

        # Advance all beams by one token
        # Each beam moves down to the child created above for its selected token.
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice]
    return original_tree

@spaces.GPU
def get_beam_search_html(input_text, number_steps, number_beams):
    """Run beam search from ``input_text`` and return the search tree as HTML.

    Args:
        input_text: Prompt to decode from.
        number_steps: Number of new tokens to generate. Cast to ``int`` because
            UI slider values may arrive as floats, and ``generate`` needs ints.
        number_beams: Beam width; also used as the number of returned sequences.

    Returns:
        The HTML string produced by ``generate_html`` for the rebuilt tree.
    """
    number_steps = int(number_steps)
    number_beams = int(number_beams)
    inputs = tokenizer([input_text], return_tensors="pt")

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        do_sample=False,
    )

    original_tree = generate_beams(
        input_text,
        outputs.scores,
        outputs.sequences,
        outputs.beam_indices,
    )
    # Named html_output so it does not shadow the stdlib ``html`` module.
    html_output = generate_html(input_text, original_tree)
    print(html_output)
    return html_output


# Gradio UI: a prompt box, sliders for step count and beam width, and a
# Markdown output that renders the HTML tree styled by STYLE.
with gr.Blocks(
    theme=gr.themes.Soft(
        text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.yellow
    ),
    css=STYLE,
) as demo:
    text = gr.Textbox(label="Sentence to decode from", value="Today is")
    steps = gr.Slider(label="Number of steps", minimum=1, maximum=8, step=1, value=4)
    beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(get_beam_search_html, inputs=[text, steps, beams], outputs=out)

# Only start the (blocking) server when run as a script, so importing this
# module (e.g. by tooling that picks up `demo`) does not launch it.
if __name__ == "__main__":
    demo.launch()