# Last change: commit 96ab487 "fix diagram (#4)" by MatrixYao.
from bokeh.events import Tap
from bokeh.io import curdoc
from bokeh.layouts import column
from bokeh.models import Div, TextInput, RadioButtonGroup, TextAreaInput, Span, Button, Panel, Tabs
from bokeh.models.tools import CrosshairTool
from demo_utils import (
get_data,
prompt_boolq,
pvp_colors,
ctl_colors,
clf_colors,
reduct,
task_best_pattern,
plot_polygons_bokeh,
advantage_text,
data_difference,
calculate_overlap,
circ_easing,
average_advantage_text,
plot_three_polygons_bokeh,
tasks,
metric_tap,
neutral_tasks, pattern_graph,
)
from text import text1, text2, text3, text4, initial_passage, initial_question, text5
########################################################################################################################
# Basic dimensions
########################################################################################################################
# Pixel sizes used by the page layout.  In the visible code only ``text_width`` is referenced; the others are
# kept so nothing that may rely on them breaks.
plot_width, plot_height = 1200, 400
sidebar_width = 400
in_text_plot_height = 300
text_width = 800  # width of the centred text column and of the text-entry widgets
widget_size = 400
########################################################################################################################
# Patternification widget
########################################################################################################################
# Inputs for a passage / question pair and a pattern selector; the prompt they produce is previewed in
# ``prompt_box`` (rendered by ``prompt_boolq``).
passage = TextAreaInput(title="篇章", rows=3, value=initial_passage, max_width=text_width, align="center")
question = TextInput(title="问题", value=initial_question, max_width=text_width, align="center")
radio_button_group = RadioButtonGroup(
    labels=["模板 1", "模板 2", "模板 3"],
    active=0,
    max_width=text_width,
    align="center",
)

# Inline CSS for the grey, centred boxes that display generated text.
box_style = {
    "display": "block",
    "margin": "0 auto",
    "width": f"{text_width}px",
    "text-align": "center",
    "white-space": "pre-wrap",
    "background": "#f4f4f4",
    "border": "1px solid #ddd",
    "color": "#666",
    "page-break-inside": "avoid",
    "font-size": "15px",
    "line-height": "1.6",
    "max-width": "100%",
    "overflow": "hidden",
    "min-height": "30px",
    "word-wrap": "break-word",
}

# Preview of the prompt built from the widgets' initial values.
prompt_box = Div(
    text=prompt_boolq(passage.value, question.value, radio_button_group.active),
    style=box_style,
    width=text_width,
    sizing_mode="scale_width",
    align="center",
)
def update_prompt(attrname, old, new):
    """Re-render the prompt preview after any of the input widgets changes.

    The three arguments follow Bokeh's ``on_change`` callback signature. They are ignored
    because the prompt is rebuilt from the current value of every widget.
    """
    rendered = prompt_boolq(passage.value, question.value, radio_button_group.active)
    prompt_box.text = rendered
# Any edit to the passage, the question or the selected pattern refreshes the preview.
for _widget, _prop in ((passage, "value"), (question, "value"), (radio_button_group, "active")):
    _widget.on_change(_prop, update_prompt)

patternification = column(
    passage,
    question,
    radio_button_group,
    prompt_box,
    sizing_mode="scale_width",
)
patternification.align = "center"
########################################################################################################################
# Advantage diagram
########################################################################################################################
# Parallel per-tab lists: entry ``i`` of each list belongs to the ``i``-th panel in ``advantage_tabs``.
advantage_plots_per_task = []
overlapping_range_per_task = []
training_points_per_task = []
clf_results_per_task = []
pvp_results_per_task = []
advantage_tabs = []
# The ``advantage_all_figures`` Tabs widget is created once every panel has been appended to
# ``advantage_tabs`` (after the per-task loop), so no placeholder widget is built here.

# Read-out box for the advantage at a tapped / integrated point.  The initial text asks the user to
# click inside the comparison region to compute the data advantage at that performance level.
advantage_box = Div(
    text="在比较区域内点击某点以计算该点对应的性能点上的数据优势",
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
)
advantage_box.align = "center"
def _add_advantage_tab(task, title, training_points, classifier_performances, pattern_performances, **plot_kwargs):
    """Build one advantage plot for *task*, record its data, and append it as a tab titled *title*.

    The per-task data lists (``training_points_per_task`` etc.) and ``advantage_tabs`` grow in lockstep,
    so the new plot's position in them equals its tab index. ``plot_kwargs`` are forwarded unchanged to
    ``plot_polygons_bokeh`` (e.g. ``x_log_scale=True``).
    """
    index = len(advantage_plots_per_task)  # position of this plot in every *_per_task list
    training_points_per_task.append(list(training_points))
    clf_results_per_task.append(reduct(classifier_performances, "accmax"))
    pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
    plot = plot_polygons_bokeh(
        task, training_points_per_task[-1], clf_results_per_task[-1], pvp_results_per_task[-1], clf_colors,
        pvp_colors, **plot_kwargs
    )
    plot.align = "center"
    plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    advantage_plots_per_task.append(plot)
    overlapping_range_per_task.append(calculate_overlap(clf_results_per_task[-1], pvp_results_per_task[-1]))
    advantage_tabs.append(Panel(child=plot, title=title))
    # Bind this plot's own index (default argument) so a tap always uses the data of the plot that
    # was tapped, independent of any global tab state.
    plot.on_event(
        Tap,
        lambda event, i=index: metric_tap(
            event,
            overlapping_range_per_task[i],
            training_points_per_task[i],
            clf_results_per_task[i],
            pvp_results_per_task[i],
            advantage_box,
            advantage_plots_per_task[i],
        ),
    )


for task in tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    # Materialise once: if get_data hands back a one-shot iterator, listing it a second time for the
    # extra MNLI tab would silently produce an empty list.
    training_points = list(training_points)
    _add_advantage_tab(task, task, training_points, classifier_performances, pattern_performances)
    if task == "MNLI":
        # MNLI gets an additional tab drawn with ``x_log_scale=True``.
        _add_advantage_tab(
            task, "MNLI (log scale)", training_points, classifier_performances, pattern_performances,
            x_log_scale=True,
        )
advantage_all_figures = Tabs(tabs=advantage_tabs)
advantage_all_figures.align = "center"
def on_integrate_click():
    """Average the data advantage across the overlap region of the active advantage tab.

    Samples ``frames - 1`` evenly spaced metric values strictly inside the active tab's overlapping
    range (both endpoints excluded). At each sample it calls ``data_difference``, moves a dashed
    horizontal ``Span`` marker there, and colours the marker by the sign of the advantage. The running
    mean of the sampled advantages is then written to ``advantage_box``.
    """
    frames = 200
    tab = advantage_all_figures.active
    plot = advantage_plots_per_task[tab]
    overlap = overlapping_range_per_task[tab]
    start, stop = overlap[0], overlap[1]
    points = training_points_per_task[tab]
    clf_results = clf_results_per_task[tab]
    pvp_results = pvp_results_per_task[tab]

    # NOTE(review): the initial colour tests the sign of the metric location, whereas the sweep below
    # colours by the sign of the advantage; the first sweep step overwrites it anyway.
    start_color = clf_colors[0] if start < 0 else pvp_colors[0]
    last_renderer = plot.renderers[-1]
    if isinstance(last_renderer, Span):
        # A marker from a previous integration/tap is reused instead of stacking new ones.
        marker = last_renderer
        marker.location = start
        marker.line_color = start_color
    else:
        marker = Span(
            location=start,
            line_alpha=0.7,
            dimension="width",
            line_color=start_color,
            line_dash="dashed",
            line_width=1,
        )
        plot.renderers.extend([marker])

    # NOTE(review): all property changes happen inside one server callback, so the client presumably
    # only sees the final marker position rather than an animation.
    average_advantage = 0
    for step in range(1, frames):
        metric_value = start + (stop - start) * (step / frames)
        advantage_value = data_difference(metric_value, overlap, points, clf_results, pvp_results)
        # Incremental running mean over the samples taken so far.
        average_advantage = ((step - 1) * average_advantage + advantage_value) / step
        marker.location = metric_value
        marker.line_color = clf_colors[0] if advantage_value < 0 else pvp_colors[0]
    advantage_box.text = average_advantage_text(average_advantage)
# Button ("integrate over the whole region!") that runs on_integrate_click for the active tab.
integrate = Button(label="对整个区域进行积分!", width=175, max_width=175, align="center")
integrate.on_click(on_integrate_click)
def on_tab_change(attr, old, new):
    """Restore the advantage read-out's instruction text when another tab is selected.

    Bokeh ``on_change`` callback; the arguments are not used.
    """
    instruction = "在比较区域内点击某点以计算该点对应的性能点上的数据优势"
    advantage_box.text = instruction
advantage_all_figures.on_change("active", on_tab_change)
# Tabs of advantage plots, then the read-out box, then the integrate button.
advantage_column = column(
    advantage_all_figures,
    advantage_box,
    integrate,
    sizing_mode="scale_width",
)
########################################################################################################################
# Null verbalizer diagram
########################################################################################################################
def _null_verbalizer_panel(task, title, training_points, clf_results, pvp_results, ctl_results, **plot_kwargs):
    """Return a tab titled *title* holding the three-polygon plot for *task*.

    ``plot_kwargs`` are forwarded unchanged to ``plot_three_polygons_bokeh`` (e.g. ``x_log_scale=True``).
    """
    null_plot = plot_three_polygons_bokeh(
        task, training_points, clf_results, pvp_results, ctl_results, clf_colors, pvp_colors, ctl_colors,
        **plot_kwargs
    )
    null_plot.align = "center"
    null_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    return Panel(child=null_plot, title=title)


null_tabs = []
for task in neutral_tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    training_points = list(training_points)
    clf_results = reduct(classifier_performances, "accmax")
    pvp_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "normal")
    ctl_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "neutral")
    null_tabs.append(_null_verbalizer_panel(task, task, training_points, clf_results, pvp_results, ctl_results))
    if task == "MNLI":
        # MNLI gets an additional tab drawn with ``x_log_scale=True``.
        null_tabs.append(_null_verbalizer_panel(
            task, "MNLI (log scale)", training_points, clf_results, pvp_results, ctl_results, x_log_scale=True
        ))
# Built once, after all panels exist (no throwaway placeholder Tabs widget).
null_all_figures = Tabs(tabs=null_tabs)
null_all_figures.align = "center"
########################################################################################################################
# Patterns diagram
########################################################################################################################
def _pattern_panel(task):
    """Return a tab titled *task* containing the plot produced by ``pattern_graph(task)``."""
    pattern_plot = pattern_graph(task)
    pattern_plot.align = "center"
    pattern_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    return Panel(child=pattern_plot, title=task)


# Built once, after all panels exist (no throwaway placeholder Tabs widget).
pattern_tabs = [_pattern_panel(task) for task in tasks]
pattern_all_figures = Tabs(tabs=pattern_tabs)
pattern_all_figures.align = "center"
########################################################################################################################
# Add write-up text
########################################################################################################################
# Shared inline CSS for the explanatory paragraphs placed between the interactive figures.
main_text_style = {
    "min-height": "100px",
    "overflow": "hidden",
    "display": "block",
    "margin": "auto",
    "width": f"{text_width}px",
    "font-size": "18px",
}
textbox1, textbox2, textbox3, textbox4, textbox5 = (
    Div(text=body, style=main_text_style) for body in (text1, text2, text3, text4, text5)
)
for _box in (textbox1, textbox2, textbox3, textbox4, textbox5):
    _box.align = "center"
########################################################################################################################
# Set up layouts and add to document
########################################################################################################################
# Write-up paragraphs interleaved with the prompt builder, advantage, null-verbalizer and pattern figures.
main_body = column(
    textbox1,
    patternification,
    textbox2,
    advantage_column,
    textbox3,
    null_all_figures,
    textbox4,
    pattern_all_figures,
    textbox5,
    sizing_mode="scale_width",
)
main_body.align = "center"
document = curdoc()
document.add_root(main_body)
# Page title: "How many data points is a prompt worth?"
document.title = "一条提示抵得上多少样本数据?"