from bokeh.events import Tap
from bokeh.io import curdoc
from bokeh.layouts import column
from bokeh.models import Div, TextInput, RadioButtonGroup, TextAreaInput, Span, Button, Panel, Tabs
from bokeh.models.tools import CrosshairTool

from demo_utils import (
    get_data,
    prompt_boolq,
    pvp_colors,
    ctl_colors,
    clf_colors,
    reduct,
    task_best_pattern,
    plot_polygons_bokeh,
    advantage_text,
    data_difference,
    calculate_overlap,
    circ_easing,
    average_advantage_text,
    plot_three_polygons_bokeh,
    tasks,
    metric_tap,
    neutral_tasks, pattern_graph,
)
from text import text1, text2, text3, text4, initial_passage, initial_question, text5

########################################################################################################################
# Basic dimensions
########################################################################################################################

plot_width = 1200
plot_height = 400
sidebar_width = 400
in_text_plot_height = 300
text_width = 800
widget_size = 400

########################################################################################################################
# Patternification widget
########################################################################################################################

passage = TextAreaInput(title="Passage", rows=3, value=initial_passage, max_width=text_width)
passage.align = "center"
question = TextInput(title="Question", value=initial_question, max_width=text_width)
question.align = "center"
radio_button_group = RadioButtonGroup(labels=["Pattern 1", "Pattern 2", "Pattern 3"], active=0, max_width=text_width)
radio_button_group.align = "center"

box_style = {
    "display": "block",
    "margin": "0 auto",
    "width": f"{text_width}px",
    "text-align": "center",
    "white-space": "pre-wrap",
    "background": "#f4f4f4",
    "border": "1px solid #ddd",
    # "border-left": "3px solid #4d4945",
    "color": "#666",
    "page-break-inside": "avoid",
    # "font-family": "monospace",
    "font-size": "15px",
    "line-height": "1.6",
    "max-width": "100%",
    "overflow": "hidden",
    "min-height": "30px",
    "word-wrap": "break-word",
}

prompt_box = Div(
    text=prompt_boolq(passage.value, question.value, radio_button_group.active),
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
)
prompt_box.align = "center"


def update_prompt(attrname, old, new):
    prompt_box.text = prompt_boolq(passage.value, question.value, radio_button_group.active)


passage.on_change("value", update_prompt)
question.on_change("value", update_prompt)
radio_button_group.on_change("active", update_prompt)

patternification = column(passage, question, radio_button_group, prompt_box, sizing_mode="scale_width")
patternification.align = "center"

########################################################################################################################
# Advantage diagram
########################################################################################################################

advantage_plots_per_task = []
overlapping_range_per_task = []
training_points_per_task = []
clf_results_per_task = []
pvp_results_per_task = []
advantage_tabs = []
advantage_all_figures = Tabs(tabs=advantage_tabs)

advantage_box = Div(
    text="Click within the comparison region to compute the data advantage for a performance level",
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
)
advantage_box.align = "center"

for task in tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    training_points_per_task.append(list(training_points))
    clf_results_per_task.append(reduct(classifier_performances, "accmax"))
    pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
    advantage_plots_per_task.append(plot_polygons_bokeh(
        task, training_points_per_task[-1], clf_results_per_task[-1], pvp_results_per_task[-1], clf_colors,
        pvp_colors
    ))
    advantage_plots_per_task[-1].align = "center"
    advantage_plots_per_task[-1].add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    overlapping_range_per_task.append(calculate_overlap(clf_results_per_task[-1], pvp_results_per_task[-1]))
    advantage_tabs.append(Panel(child=advantage_plots_per_task[-1], title=task))

    advantage_plots_per_task[-1].on_event(
        Tap,
        lambda event: metric_tap(
            event,
            overlapping_range_per_task[advantage_all_figures.active],
            training_points_per_task[advantage_all_figures.active],
            clf_results_per_task[advantage_all_figures.active],
            pvp_results_per_task[advantage_all_figures.active],
            advantage_box,
            advantage_plots_per_task[advantage_all_figures.active],
        ),
    )

    if task == "MNLI":
        training_points_per_task.append(list(training_points))
        clf_results_per_task.append(reduct(classifier_performances, "accmax"))
        pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
        advantage_plots_per_task.append(plot_polygons_bokeh(
            task, training_points_per_task[-1], clf_results_per_task[-1], pvp_results_per_task[-1], clf_colors,
            pvp_colors, x_log_scale=True
        ))
        advantage_plots_per_task[-1].align = "center"
        advantage_plots_per_task[-1].add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
        overlapping_range_per_task.append(calculate_overlap(clf_results_per_task[-1], pvp_results_per_task[-1]))
        advantage_tabs.append(Panel(child=advantage_plots_per_task[-1], title="MNLI (log scale)"))

        advantage_plots_per_task[-1].on_event(
            Tap,
            lambda event: metric_tap(
                event,
                overlapping_range_per_task[advantage_all_figures.active],
                training_points_per_task[advantage_all_figures.active],
                clf_results_per_task[advantage_all_figures.active],
                pvp_results_per_task[advantage_all_figures.active],
                advantage_box,
                advantage_plots_per_task[advantage_all_figures.active],
            ),
        )

advantage_all_figures = Tabs(tabs=advantage_tabs)
advantage_all_figures.align = "center"


def on_integrate_click():
    frames = 200
    initial_placement = overlapping_range_per_task[advantage_all_figures.active][0]

    if not isinstance(advantage_plots_per_task[advantage_all_figures.active].renderers[-1], Span):
        metric_line = Span(
            location=initial_placement,
            line_alpha=0.7,
            dimension="width",
            line_color=clf_colors[0] if initial_placement < 0 else pvp_colors[0],
            line_dash="dashed",
            line_width=1,
        )
        advantage_plots_per_task[advantage_all_figures.active].renderers.extend([metric_line])
    else:
        advantage_plots_per_task[advantage_all_figures.active].renderers[-1].location = initial_placement
        advantage_plots_per_task[advantage_all_figures.active].renderers[-1].line_color = clf_colors[
            0] if initial_placement < 0 else pvp_colors[0]

    average_advantage = 0
    for i in range(1, frames):
        metric_value = overlapping_range_per_task[advantage_all_figures.active][0] + (
                overlapping_range_per_task[advantage_all_figures.active][1] -
                overlapping_range_per_task[advantage_all_figures.active][0]) * (i / frames)
        advantage_value = data_difference(metric_value, overlapping_range_per_task[advantage_all_figures.active],
                                          training_points_per_task[advantage_all_figures.active],
                                          clf_results_per_task[advantage_all_figures.active],
                                          pvp_results_per_task[advantage_all_figures.active])
        average_advantage = ((i - 1) * average_advantage + advantage_value) / i

        advantage_plots_per_task[advantage_all_figures.active].renderers[-1].location = metric_value
        advantage_plots_per_task[advantage_all_figures.active].renderers[-1].line_color = clf_colors[
            0] if advantage_value < 0 else pvp_colors[0]
        advantage_box.text = average_advantage_text(average_advantage)


integrate = Button(width=175, max_width=175, label="Integrate over the whole region!")
integrate.align = "center"
integrate.on_click(on_integrate_click)


def on_tab_change(attr, old, new):
    advantage_box.text = "Click within the comparison region to compute the data advantage for a performance level"


advantage_all_figures.on_change('active', on_tab_change)

advantage_column = column(advantage_all_figures, advantage_box, integrate, sizing_mode="scale_width")

########################################################################################################################
# Null verbalizer diagram
########################################################################################################################

null_tabs = []
null_all_figures = Tabs(tabs=null_tabs)

for task in neutral_tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    training_points = list(training_points)
    clf_results = reduct(classifier_performances, "accmax")
    pvp_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "normal")
    ctl_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "neutral")
    null_plot = plot_three_polygons_bokeh(task, training_points, clf_results, pvp_results, ctl_results, clf_colors,
                                          pvp_colors, ctl_colors)
    null_plot.align = "center"
    null_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    null_tabs.append(Panel(child=null_plot, title=task))

    if task == "MNLI":
        null_plot = plot_three_polygons_bokeh(task, training_points, clf_results, pvp_results, ctl_results, clf_colors,
                                              pvp_colors, ctl_colors, x_log_scale=True)
        null_plot.align = "center"
        null_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
        null_tabs.append(Panel(child=null_plot, title="MNLI (log scale)"))

null_all_figures = Tabs(tabs=null_tabs)
null_all_figures.align = "center"

########################################################################################################################
# Patterns diagram
########################################################################################################################

pattern_tabs = []
pattern_all_figures = Tabs(tabs=pattern_tabs)

for task in tasks:
    pattern_plot = pattern_graph(task)
    pattern_plot.align = "center"
    pattern_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    pattern_tabs.append(Panel(child=pattern_plot, title=task))

pattern_all_figures = Tabs(tabs=pattern_tabs)
pattern_all_figures.align = "center"

########################################################################################################################
# Add write-up text
########################################################################################################################

main_text_style = {
    "min-height": "100px",
    "overflow": "hidden",
    "display": "block",
    "margin": "auto",
    "width": f"{text_width}px",
    "font-size": "18px",
}

textbox1 = Div(text=text1, style=main_text_style)
textbox2 = Div(text=text2, style=main_text_style)
textbox3 = Div(text=text3, style=main_text_style)
textbox4 = Div(text=text4, style=main_text_style)
textbox5 = Div(text=text5, style=main_text_style)
textbox1.align = "center"
textbox2.align = "center"
textbox3.align = "center"
textbox4.align = "center"
textbox5.align = "center"

########################################################################################################################
# Set up layouts and add to document
########################################################################################################################

main_body = column(textbox1, patternification, textbox2, advantage_column, textbox3, null_all_figures, textbox4, pattern_all_figures,
                   textbox5, sizing_mode="scale_width")
main_body.align = "center"

curdoc().add_root(main_body)
curdoc().title = "How many data points is a prompt worth ?"