import io
import re

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer

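# Short name -> Hugging Face Hub id for each tokenizer offered in the dropdown.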
tokenizers = {
    "bert": "google-bert/bert-base-uncased",
    "bloom": "bigscience/bloom-560m",
    "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
    "chatglm3": "THUDM/chatglm3-6b",
    "falcon": "tiiuae/falcon-7b",
    "gpt-neox": "EleutherAI/gpt-neox-20b",
    "llama": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
    "magicoder": "ise-uiuc/Magicoder-S-DS-6.7B",
    "mistral": "mistralai/Mistral-7B-v0.1",
    "opt": "facebook/opt-2.7b",
    "phi-2": "microsoft/phi-2",
    "pythia": "EleutherAI/pythia-1.4b-deduped",
    "roberta": "FacebookAI/roberta-base",
    "qwen": "Qwen/Qwen1.5-7B-Chat",
    "starcoder": "bigcode/starcoder2-7b",
    "t5": "google-t5/t5-base",
}

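# Only the model ids are needed as dropdown choices (custom ids are also allowed).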
tokenizers = list(tokenizers.values())

def plot_histogram(data):
    """Render a histogram of per-item token counts and return it as a PIL image."""
    # Use an explicit Figure/Axes instead of pyplot's global state, which is
    # safer when Gradio serves several requests from the same process.
    fig, ax = plt.subplots()
    ax.hist(data)
    ax.set_title("Histogram of number of tokens per dataset item")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def count(model_id, dataset_id, config, split, column, add_special_tokens=True):
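    """Tokenize one dataset column and summarize per-item token/word/char counts."""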
    tokencounter = []
    wordcounter = []
    charcounter = []
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
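    # trust_remote_code lets tokenizers that ship custom code (e.g. chatglm3) load.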
    if config == "":
        config = None  # an empty textbox means "use the dataset's default config"
    dataset = load_dataset(dataset_id, config, split=split, trust_remote_code=True)
    pattern = r"[a-zA-Z]+"
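    # Single pass over the split: collect token, character and approximate word counts.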
    for item in dataset:
        tokens = tokenizer(item[column], add_special_tokens=add_special_tokens)["input_ids"]
        tokencounter.append(len(tokens))
        charcounter.append(len(item[column]))
        # crude word count (runs of ASCII letters); not 100% accurate but good enough
        words = re.findall(pattern, item[column])
        wordcounter.append(len(words))

    # Summarize each distribution with pandas; one row per count type.
    df = pd.DataFrame(tokencounter).describe().T
    df.insert(0, "type", "tokens")
    dfw = pd.DataFrame(wordcounter).describe().T
    dfw.insert(0, "type", "words")
    dfc = pd.DataFrame(charcounter).describe().T
    dfc.insert(0, "type", "chars")
    df = pd.concat([df, dfw, dfc], ignore_index=True).round(1)
    df.drop("count", axis=1, inplace=True)

    return plot_histogram(tokencounter), df

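# Example call: count("mistralai/Mistral-7B-v0.1", "imdb", "", "test", "text")
# returns the histogram image and a three-row stats table (tokens, words, chars).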

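# Gradio UI: pick a tokenizer, then point at any Hub dataset/config/split/column.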
demo = gr.Interface(
    fn=count,
    title="Dataset token counts and distribution",
    inputs=[
        gr.Dropdown(label="Tokenizer", choices=tokenizers, allow_custom_value=True),
        gr.Textbox(label="Dataset"),
        gr.Textbox(label="Config"),
        gr.Textbox(label="Split"),
        gr.Textbox(label="Column"),
        gr.Checkbox(label="Add special tokens", value=True),
    ],
    outputs=[
        gr.Image(),
        gr.Dataframe(label="Token, word and character counts per dataset item"),
    ],
    examples=[
        # Each example supplies a value for all six inputs, including the checkbox.
        ["mistralai/Mistral-7B-v0.1", "gsarti/flores_101", "eng", "dev", "sentence", True],
        ["mistralai/Mistral-7B-v0.1", "Muennighoff/flores200", "eng_Latn", "dev", "sentence", True],
        ["mistralai/Mistral-7B-v0.1", "hails/mmlu_no_train", "elementary_mathematics", "test", "question", True],
        ["mistralai/Mistral-7B-v0.1", "gsm8k", "main", "test", "question", True],
        ["mistralai/Mistral-7B-v0.1", "locuslab/TOFU", "world_facts", "train", "question", True],
        ["mistralai/Mistral-7B-v0.1", "imdb", "", "test", "text", True],
        ["mistralai/Mistral-7B-v0.1", "wikitext", "wikitext-2-v1", "validation", "text", True],
        ["mistralai/Mistral-7B-v0.1", "zeroshot/twitter-financial-news-sentiment", "", "validation", "text", True],
    ],
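    # Cached examples are computed once at startup so they display instantly.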
    cache_examples=True
)

demo.launch()