import os
import random
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
from datasets import load_dataset

# Subsets of the "CarperAI/pilev2_smol_metadata" repository shown in the demo.
# Each name maps to the file ``data/<name>/data.json`` inside that repository.
dataset_names = [
    "AI4Code",
    "AMPS",
    "ASFPublicMail",
    "CPDataset",
    "DMMath",
    "Discourse",
    "Enwiki",
    "EuroParliamentProceedings",
    "FreeLaw_Options",
    "GithubDiff",
    "GithubIssues",
    "Gutenberg",
    "LeetCode",
    "PileOfLaw",
    "PubMed",
    "S2ORC",
    "StackExchange",
    "USENET",
    "USPTO",
    "UbuntuIRC",
    "arXiv",
]

# Metadata columns that are materialised as NumPy arrays for every dataset so
# the plotting code can work on them without touching the dataset again.
_criteria_columns = (
    "check_word_number_criteria",
    "check_char_repetition_criteria",
    "check_flagged_words_criteria",
    "check_stop_word_ratio_criteria",
    "check_perplexity_criteria",
    "check_compression_ratio_criteria",
)

# The token is identical for every request; a missing HF_TOKEN raises KeyError.
_hf_token = os.environ["HF_TOKEN"]

# Maps dataset name -> {"ds": <loaded dataset>, "<column>": <np.ndarray>, ...}.
dataset_data = {}
for name in dataset_names:
    # NOTE(review): `use_auth_token` may be deprecated in newer `datasets`
    # releases in favour of `token` -- confirm against the pinned version.
    ds = load_dataset(
        "CarperAI/pilev2_smol_metadata",
        data_files=f"data/{name}/data.json",
        use_auth_token=_hf_token,
        split="train",
        # download_mode="force_redownload",
    )
    entry = {"ds": ds}
    entry.update((column, np.array(ds[column])) for column in _criteria_columns)
    dataset_data[name] = entry

def plt_plot(criteria, dataset, threshold, greater_than=True):
    """Draw a histogram of one criteria column with the threshold marked.

    Args:
        criteria: Key of the cached criteria array in ``dataset_data[dataset]``.
        dataset: Name of the dataset (a key of ``dataset_data``).
        threshold: Cut-off value; drawn as a dashed red vertical line.
        greater_than: If true, values strictly above ``threshold`` count as
            removed; otherwise values strictly below it do.

    Returns:
        The matplotlib figure. Its title shows the dataset name and the share
        of values that the threshold would remove.
    """
    # Drop figures left over from earlier calls so pyplot does not keep
    # accumulating them.
    plt.close("all")

    values = dataset_data[dataset][criteria]
    removed = values > threshold if greater_than else values < threshold
    total = len(values)
    # An empty column has nothing to remove; avoid dividing by zero.
    perc = np.count_nonzero(removed) / total if total else 0.0

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.hist(values, bins=50, color="black")
    ax.axvline(threshold, color="r", linestyle="dashed", linewidth=2)
    ax.set_title(f"{dataset} (removed {perc:.2%})")
    # Label and lay out through the Axes/Figure objects rather than pyplot's
    # implicit "current axes", so the result does not depend on global state.
    ax.set_xlabel("Value")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    return fig

def check_filtered(criteria, dataset, threshold, greater_than=True):
    """Return the text of one random example that the threshold would remove.

    Args:
        criteria: Key of the cached criteria array in ``dataset_data[dataset]``.
        dataset: Name of the dataset (a key of ``dataset_data``).
        threshold: Cut-off value the criteria column is compared against.
        greater_than: If true, examples strictly above ``threshold`` match;
            otherwise examples strictly below it do.

    Returns:
        The ``"text"`` field of a uniformly chosen matching example, or the
        string ``"No examples found"`` when nothing matches.
    """
    entry = dataset_data[dataset]
    # The cached array was built from the same dataset column, so position i
    # in the array corresponds to row i of the dataset. Comparing the array
    # avoids running a Python-level filter over every row on each click.
    values = entry[criteria]
    mask = values > threshold if greater_than else values < threshold
    matches = np.flatnonzero(mask)
    if matches.size == 0:
        return "No examples found"

    # Convert to a plain int: the index is used for a single-row lookup.
    row = int(matches[random.randrange(matches.size)])
    return entry["ds"][row]["text"]

def _add_criteria_tab(dataset, label, criteria, maximum, greater_than=True):
    """Add one tab with a histogram and a sample viewer for ``criteria``.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        dataset: The dataset selector component shared by all tabs.
        label: Title shown on the tab.
        criteria: Criteria column passed to ``plt_plot`` / ``check_filtered``.
        maximum: Upper bound of the threshold slider (the lower bound is 0).
        greater_than: Forwarded to ``plt_plot`` / ``check_filtered``; selects
            whether values above or below the threshold count as removed.
    """
    with gr.Tab(label):
        plot = gr.Plot()
        threshold = gr.Slider(minimum=0, maximum=maximum, label="Threshold")
        calculate = gr.Button("Calculate")
        check = gr.Button("Check Filtered Data")
        filtered_data = gr.Textbox(lines=5, label="Filtered Data")
        calculate.click(
            partial(plt_plot, criteria, greater_than=greater_than),
            [dataset, threshold],
            plot,
        )
        check.click(
            partial(check_filtered, criteria, greater_than=greater_than),
            [dataset, threshold],
            filtered_data,
        )


# (tab label, criteria column, slider maximum, greater_than)
_criteria_tabs = (
    ("Character Repetition Criteria", "check_char_repetition_criteria", 1, True),
    ("Number of Words Criteria", "check_word_number_criteria", 50_000, True),
    ("Stop Word Ratio Criteria", "check_stop_word_ratio_criteria", 1, True),
    ("Flagged Word Criteria", "check_flagged_words_criteria", 1, True),
    ("Perplexity Criteria", "check_perplexity_criteria", 50_000, True),
    # Compression ratio is the only criteria checked with "less than".
    ("Compression Ratio Criteria", "check_compression_ratio_criteria", 1, False),
)

with gr.Blocks() as demo:
    dataset = gr.Radio(dataset_names, label="Dataset", value="arXiv")

    # One tab per criteria. Every tab shares the dataset selector above and
    # has its own plot, threshold slider, buttons and output box.
    for _label, _criteria, _maximum, _greater_than in _criteria_tabs:
        _add_criteria_tab(dataset, _label, _criteria, _maximum, _greater_than)

if __name__ == "__main__":
    demo.launch()