from transformers import AutoTokenizer
import gradio as gr
import random

# Default tokenizer: pre-selected in the dropdown and loaded when the page opens.
checkpoint = "dslim/bert-base-NER"
# Suggestions shown in the tokenizer dropdown.
checkpoints = [
    checkpoint,
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "microsoft/phi-2",
    "openai/whisper-large-v3",
    "NousResearch/Nous-Hermes-2-Yi-34B",
    "bert-base-cased",
]

placeholder = "Type anything in this text box and hit Tokenize!"
# Pangram-style sample sentences sprinkled with emoji, including multi-codepoint
# ZWJ sequences, to show how tokenizers split non-ASCII text. The invisible
# code points (U+200D zero-width joiner, U+FE0F emoji presentation selector)
# are written as escapes so they cannot be lost by editors or re-encoding.
sequences = [
    "The quick brown 🦊 fox jumps over the lazy 🐕 dog!",
    "How vexingly ⏩ quick daft 🦓 zebras jump?",
    "Pack my 📦 box with five dozen 🍷 liquor jugs.",
    "The five 🥊 boxing 🧙\u200d♂\ufe0f wizards jump quickly~",
    "While making deep ⛏\ufe0f excavations we found some quaint bronze 💍 jewelry!",
    "Whenever the 🦊 fox jumped, the 🐿\ufe0f squirrel gazed suspiciously...",
    "We promptly 🧑\u200d⚖\ufe0f judged antique ivory buckles for the next 🏆 prize.",
]

def randomize_sequence():
    """Pick one of the sample ``sequences`` at random.

    Serves as the "Randomize!" button handler and, through the ``sequence``
    alias below, as the initial value of the Sequence text area.
    """
    sample = random.choice(sequences)
    return sample


# Bound to the function object itself, not a call result: Gradio calls a
# callable ``value`` each time the page loads, so every visit starts from a
# freshly randomized sample sentence.
sequence = randomize_sequence

def load_vocab(target_model, current_model):
    """Load ``target_model``'s tokenizer if needed and summarise its vocabulary.

    Args:
        target_model: Checkpoint currently selected in the dropdown.
        current_model: Checkpoint recorded by the previous call (fed back from
            this function's first output).

    Returns:
        Tuple ``(checkpoint, vocab_size, token_zero, vocab_listing)``:
        ``vocab_size`` is the full number of vocabulary entries,
        ``token_zero`` is the token with the lowest id, and
        ``vocab_listing`` holds every other token, one per line, in ascending
        id order, so line N shows the token with id N when ids are
        contiguous from 0.
    """
    checkpoint = target_model
    if target_model == current_model:
        gr.Info(f"Tokenizer already loaded: {checkpoint}")
    else:
        load_tokenizer(checkpoint)
        gr.Info(f"Tokenizer loaded: {checkpoint}")
    # get_vocab() (token -> id) is provided by every transformers tokenizer,
    # whereas the ``vocab`` attribute is not defined on all tokenizer classes.
    vocab = sorted(tokenizer.get_vocab().items(), key=lambda item: item[1])
    # Count before splitting off the first token so the reported size is
    # the real vocabulary size.
    vocab_size = len(vocab)
    unk = vocab[0][0] if vocab else ""
    # The lowest-id token is shown separately ("Token 0") and left out of the
    # listing so the code view's 1-based line numbers coincide with token ids.
    vocab_sorted = "\n".join(token for token, _ in vocab[1:])
    gr.Info(f"Tokenizer vocab size: {vocab_size}")
    return checkpoint, vocab_size, unk, vocab_sorted

def load_tokenizer(checkpoint):
    """Load the tokenizer for ``checkpoint`` into the module-global ``tokenizer``.

    Uses ``AutoTokenizer.from_pretrained``, which fetches only tokenizer files.
    If the current global tokenizer was already loaded from ``checkpoint`` it
    is reused instead of being reloaded on every tokenize/decode click.

    NOTE(review): ``tokenizer`` is a module global, so it is presumably shared
    by all concurrent Gradio sessions in this process.

    Raises:
        ValueError: if ``checkpoint`` is empty or None.
        gr.Error: if the tokenizer cannot be loaded (message shown in the UI).
    """
    # ``global`` applies to the whole function body wherever it appears, so
    # declare it unconditionally.
    global tokenizer
    if not checkpoint:
        raise ValueError("Tokenizer cannot be empty!")
    current = globals().get("tokenizer")
    # ``name_or_path`` holds the id the tokenizer was loaded from.
    if current is not None and getattr(current, "name_or_path", None) == checkpoint:
        return
    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    except Exception as error:
        gr.Warning("Unexpected error!")
        raise gr.Error(f"{error}") from error

def tokenize_er(checkpoint, sequence):
    """Tokenize ``sequence`` with the tokenizer for ``checkpoint``.

    Returns:
        A list of ``[token, id]`` rows for the Token/ID table, or
        ``[[None, None]]`` (after a UI warning) when the global ``tokenizer``
        is not defined.

    Raises:
        gr.Error: on any other failure, so the message is shown in the UI.
    """
    try:
        load_tokenizer(checkpoint)
        tokens = tokenizer.tokenize(sequence)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        # convert_tokens_to_ids maps a list element-wise, so the lists align.
        return [[token, token_id] for token, token_id in zip(tokens, ids)]
    except NameError:
        gr.Warning("Select Tokenizer before sequencing.")
        return [[None, None]]
    except gr.Error:
        # Already reported by load_tokenizer: propagate it unchanged instead of
        # showing a second "Unexpected error!" warning and re-wrapping it.
        raise
    except Exception as error:
        gr.Warning("Unexpected error!")
        raise gr.Error(f"{error}") from error

def _cell_to_id(value):
    """Convert a Vocabulary ID cell to ``int``; blank or invalid cells become 0."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # None (blank cell), non-numeric text, NaN or +/-inf.
        return 0


def de_tokenize_er(checkpoint, pairs):
    """Decode the edited Token/ID table in two independent ways.

    Args:
        checkpoint: Tokenizer checkpoint to decode with.
        pairs: Table rows as ``[token, id]`` lists.

    Returns:
        ``(tokens_ids, decoded_tokens, decoded_ids)``: the Token column
        re-encoded to ids, those ids decoded back to text, and the ID column
        decoded directly. ``(None, None, None)`` (after a UI warning) when the
        global ``tokenizer`` is not defined.

    Raises:
        gr.Error: on any other failure, so the message is shown in the UI.
    """
    try:
        load_tokenizer(checkpoint)
        # Skip non-text Token cells (e.g. the table's initial ``[None, 0]``
        # row); they cannot be converted back to ids.
        tokens = [row[0] for row in pairs if isinstance(row[0], str)]
        ids = [_cell_to_id(row[1]) for row in pairs]
        tokens_ids = tokenizer.convert_tokens_to_ids(tokens)
        decoded_tokens = tokenizer.decode(tokens_ids)
        decoded_ids = tokenizer.decode(ids)
        return tokens_ids, decoded_tokens, decoded_ids
    except NameError:
        gr.Warning("Tokenize sequence before decoding.")
        return None, None, None
    except gr.Error:
        # Already reported by load_tokenizer: propagate it unchanged.
        raise
    except Exception as error:
        gr.Warning("Unexpected error!")
        raise gr.Error(f"{error}") from error

# Page layout: the left column walks through selecting a tokenizer, tokenizing
# a sequence and decoding the editable Token/ID table; the right column shows
# data about the loaded tokenizer.
with gr.Blocks() as frontend:
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown(
                "# 🐇 Tokenizaminer\n"
                "### The Tokenizer Examiner, or the Tokeniza Miner... 🕵\ufe0f🕳\ufe0f\n"
                "The purpose of this tool is to examine the vocabulary and tokens "
                "of a model's tokenizer and play with the results.\n"
                "Note how the Vocabulary ID lines up with the full Vocabulary index "
                "on the right ➡\ufe0f\n"
                "\n"
                "⚠\ufe0f Loading the full vocabulary can take a few seconds and the "
                "browser might stutter."
            )
            with gr.Row():
                gr.Markdown(
                    "\n#### 1. Select Tokenizer\n"
                    "Select from the list or enter any model from 🤗 Hugging Face "
                    "Models; it will only download the Tokenizer data! "
                    "Image models won't work here."
                )
            with gr.Row():
                # allow_custom_value lets users type any model id, not only the
                # listed suggestions.
                input_checkpoint = gr.Dropdown(
                    label="Tokenizer",
                    choices=checkpoints,
                    value=checkpoint,
                    allow_custom_value=True,
                    show_label=False,
                    container=False,
                )
            with gr.Row():
                gr.Markdown("\n#### 2. Sequence & Tokenize")
            with gr.Row():
                input_sequence = gr.TextArea(
                    label="Sequence",
                    value=sequence,
                    placeholder=placeholder,
                    lines=3,
                    interactive=True,
                    show_label=False,
                    container=False,
                )
            with gr.Row():
                btn_tokenize = gr.Button(value="Tokenize!")
                btn_random_seq = gr.Button(value="Randomize!")
            with gr.Row():
                gr.Markdown("\n#### 3. Decode\nYou can select and edit each cell individually - then hit Decode!")
            with gr.Row():
                token_id_pair = gr.DataFrame(
                    col_count=(2, "fixed"),
                    headers=["Token", "Vocabulary ID"],
                    value=[[None, 0]],
                    type="array",
                    datatype=["str", "number"],
                    height=400,
                    interactive=True,
                )
            with gr.Row():
                btn_decode = gr.Button(value="Decode")
                btn_clear_pairs = gr.ClearButton(value="Clear Token/IDs", components=[token_id_pair])
            with gr.Row():
                with gr.Column():
                    output_decoded_token_ids = gr.TextArea(label="Re-encoded Tokens", interactive=False)
                    output_decoded_tokens = gr.TextArea(label="Decoded Re-encoded Tokens", interactive=False)
                with gr.Column():
                    output_decoded_ids = gr.TextArea(label="Decoded IDs", interactive=False)
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🎲 Tokenizer Data")
                # Hidden field: load_vocab writes the checkpoint it loaded here
                # and reads it back as ``current_model`` on the next call.
                output_checkpoint = gr.Textbox(visible=False)
                output_vocab_count = gr.Number(label="Vocab Size", interactive=False)
                output_token_zero = gr.Textbox(label="Token 0", interactive=False)
                output_vocab = gr.Code(label="Vocabulary IDs")

    # Event wiring (registration order unchanged). load_vocab runs whenever
    # the dropdown changes and once when the page loads.
    vocab_inputs = [input_checkpoint, output_checkpoint]
    vocab_outputs = [output_checkpoint, output_vocab_count, output_token_zero, output_vocab]
    input_checkpoint.change(fn=load_vocab, inputs=vocab_inputs, outputs=vocab_outputs, queue=True)
    btn_tokenize.click(fn=tokenize_er, inputs=[input_checkpoint, input_sequence], outputs=[token_id_pair], queue=True)
    btn_random_seq.click(fn=randomize_sequence, inputs=[], outputs=[input_sequence])
    btn_decode.click(
        fn=de_tokenize_er,
        inputs=[input_checkpoint, token_id_pair],
        outputs=[output_decoded_token_ids, output_decoded_tokens, output_decoded_ids],
        queue=True,
    )
    frontend.load(fn=load_vocab, inputs=vocab_inputs, outputs=vocab_outputs, queue=True)

frontend.launch()