|
import plotly.graph_objects as go |
|
import textwrap |
|
import re |
|
from collections import defaultdict |
|
|
|
def generate_subplot1(paraphrased_sentence, scheme_sentences, highlight_info, common_grams):
    """Draw the paraphrased sentence as a root node linked to each scheme sentence.

    The root is plotted at level 0 and every scheme sentence at level 1, with
    one labelled edge from the root to each scheme sentence.

    Args:
        paraphrased_sentence: Text of the root node.
        scheme_sentences: List of texts for the level-1 nodes.
        highlight_info: Mapping, or iterable of ``(word, color)`` pairs (it is
            passed to ``dict()``), naming the words to color in the node text.
        common_grams: Iterable of ``(index, gram)`` pairs; every whole-word
            occurrence of ``gram`` in a node is prefixed with ``(index)``.

    Returns:
        The ``go.Figure`` containing node markers, boxed sentence annotations,
        connecting lines and edge labels.
    """
    x_gap = 2    # horizontal distance between two consecutive levels
    y_gap = 10   # vertical distance between neighbouring nodes of one level
    sampling_labels = [
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling",
    ]
    edge_texts = [
        "Highest Entropy Masking",
        "Pseudo-random Masking",
        "Random Masking",
    ] + sampling_labels * 3

    global_color_map = dict(highlight_info)

    # Compile every pattern once. re.escape keeps grams/words that contain
    # regex metacharacters from being interpreted as patterns; empty entries
    # are skipped because r"\b\b" would match at every word boundary.
    gram_patterns = [
        (idx, re.compile(rf"\b{re.escape(lcs)}\b"))
        for idx, lcs in common_grams
        if lcs
    ]
    highlight_patterns = [
        (re.compile(rf"\b{re.escape(word)}\b", re.IGNORECASE), "{{" + word + "}}")
        for word in global_color_map
        if word
    ]

    def apply_lcs_numbering(sentence):
        """Prefix each whole-word occurrence of a common gram with "(index)"."""
        for idx, pattern in gram_patterns:
            # A callable replacement avoids backslash/group-reference
            # processing of the gram text.
            sentence = pattern.sub(
                lambda match, idx=idx: f"({idx}){match.group(0)}", sentence
            )
        return sentence

    def highlight_words(sentence):
        """Wrap each highlighted word in ``{{...}}`` using the map's spelling."""
        for pattern, marker in highlight_patterns:
            # The marker carries the key exactly as stored in the color map so
            # the later color lookup succeeds even for case-insensitive hits.
            sentence = pattern.sub(lambda _match, marker=marker: marker, sentence)
        return sentence

    def color_highlighted_words(text):
        """Turn every ``{{word}}`` marker into a colored ``<span>``."""
        def to_span(match):
            word = match.group(1)
            color = global_color_map.get(word, 'black')
            return f"<span style='color: {color};'>{word}</span>"

        return re.sub(r'\{\{(.*?)\}\}', to_span, text)

    def layout_positions(levels):
        """Return plotted (x, y) per node, each level centred around y == 0."""
        level_heights = defaultdict(int)
        for level in levels:
            level_heights[level] += 1
        next_offset = {
            level: -(height - 1) / 2 for level, height in level_heights.items()
        }
        positions = []
        for level in levels:
            positions.append((level * x_gap, next_offset[level] * y_gap))
            next_offset[level] += 1
        return positions

    # Node 0 is the root (level 0); every other node is a level-1 child.
    nodes = [paraphrased_sentence] + scheme_sentences
    levels = [0] + [1] * (len(nodes) - 1)
    edges = [(0, child) for child in range(1, len(nodes))]
    positions = layout_positions(levels)

    wrapped_nodes = [
        '<br>'.join(
            textwrap.wrap(highlight_words(apply_lcs_numbering(node)), width=55)
        )
        for node in nodes
    ]

    fig1 = go.Figure()

    # One marker plus one boxed text annotation per node.
    for node, (x, y) in zip(wrapped_nodes, positions):
        fig1.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none'
        ))
        fig1.add_annotation(
            x=x,
            y=y,
            text=color_highlighted_words(node),
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=300,
            height=120
        )

    # One line per edge, labelled slightly above its midpoint.
    for i, (start, end) in enumerate(edges):
        x0, y0 = positions[start]
        x1, y1 = positions[end]
        fig1.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # Fall back to a generic label instead of raising IndexError when
        # there are more edges than predefined labels.
        text = edge_texts[i] if i < len(edge_texts) else f"Edge {i+1}"
        fig1.add_annotation(
            x=(x0 + x1) / 2,
            y=(y0 + y1) / 2 + 0.8,
            text=text,
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig1.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1435,
        height=1000
    )

    return fig1
|
|
|
def generate_subplot2(scheme_sentences, sampled_sentence, highlight_info, common_grams):
    """Draw scheme sentences (level 0) linked to the sampled sentences (level 1).

    Sampled sentences are attached in order: the first four to the first
    scheme sentence, the next four to the second, and all remaining ones to
    the third.

    Args:
        scheme_sentences: List of texts for the level-0 nodes; at least three
            are required.
        sampled_sentence: List of texts for the level-1 nodes.
        highlight_info: Mapping, or iterable of ``(word, color)`` pairs (it is
            passed to ``dict()``), naming the words to color in the node text.
        common_grams: Iterable of ``(index, gram)`` pairs; every whole-word
            occurrence of ``gram`` in a node is prefixed with ``(index)``.

    Returns:
        The ``go.Figure`` containing node markers, boxed sentence annotations,
        connecting lines and edge labels.

    Raises:
        ValueError: If fewer than three scheme sentences are supplied.
    """
    x_gap = 2    # horizontal distance between two consecutive levels
    y_gap = 10   # vertical distance between neighbouring nodes of one level
    sampling_labels = [
        "Greedy Sampling",
        "Temperature Sampling",
        "Exponential Minimum Sampling",
        "Inverse Transform Sampling",
    ]
    # NOTE(review): labels are picked by edge index from the start of this
    # list, so the first three edges receive the masking labels - confirm
    # that this is the intended labelling for this figure.
    edge_texts = [
        "Highest Entropy Masking",
        "Pseudo-random Masking",
        "Random Masking",
    ] + sampling_labels * 3

    nodes = scheme_sentences + sampled_sentence
    para_len = len(scheme_sentences)

    if para_len < 3:
        raise ValueError("There should be at least 3 L0 nodes to attach edges correctly.")

    global_color_map = dict(highlight_info)

    # Compile every pattern once. re.escape keeps grams/words that contain
    # regex metacharacters from being interpreted as patterns; empty entries
    # are skipped because r"\b\b" would match at every word boundary.
    gram_patterns = [
        (idx, re.compile(rf"\b{re.escape(lcs)}\b"))
        for idx, lcs in common_grams
        if lcs
    ]
    highlight_patterns = [
        (re.compile(rf"\b{re.escape(word)}\b", re.IGNORECASE), "{{" + word + "}}")
        for word in global_color_map
        if word
    ]

    def apply_lcs_numbering(sentence):
        """Prefix each whole-word occurrence of a common gram with "(index)"."""
        for idx, pattern in gram_patterns:
            # A callable replacement avoids backslash/group-reference
            # processing of the gram text.
            sentence = pattern.sub(
                lambda match, idx=idx: f"({idx}){match.group(0)}", sentence
            )
        return sentence

    def highlight_words(sentence):
        """Wrap each highlighted word in ``{{...}}`` using the map's spelling."""
        for pattern, marker in highlight_patterns:
            # The marker carries the key exactly as stored in the color map so
            # the later color lookup succeeds even for case-insensitive hits.
            sentence = pattern.sub(lambda _match, marker=marker: marker, sentence)
        return sentence

    def color_highlighted_words(text):
        """Turn every ``{{word}}`` marker into a colored ``<span>``."""
        def to_span(match):
            word = match.group(1)
            color = global_color_map.get(word, 'black')
            return f"<span style='color: {color};'>{word}</span>"

        return re.sub(r'\{\{(.*?)\}\}', to_span, text)

    def layout_positions(levels):
        """Return plotted (x, y) per node, each level centred around y == 0."""
        level_heights = defaultdict(int)
        for level in levels:
            level_heights[level] += 1
        next_offset = {
            level: -(height - 1) / 2 for level, height in level_heights.items()
        }
        positions = []
        for level in levels:
            positions.append((level * x_gap, next_offset[level] * y_gap))
            next_offset[level] += 1
        return positions

    # Scheme sentences come first (level 0), sampled sentences after (level 1).
    levels = [0] * para_len + [1] * (len(nodes) - para_len)
    # Children 0-3 hang off parent 0, children 4-7 off parent 1, the rest off
    # parent 2.
    edges = [
        (min(order // 4, 2), child)
        for order, child in enumerate(range(para_len, len(nodes)))
    ]
    positions = layout_positions(levels)

    wrapped_nodes = [
        '<br>'.join(
            textwrap.wrap(highlight_words(apply_lcs_numbering(node)), width=80)
        )
        for node in nodes
    ]

    fig2 = go.Figure()

    # One marker plus one boxed text annotation per node.
    for node, (x, y) in zip(wrapped_nodes, positions):
        fig2.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none'
        ))
        fig2.add_annotation(
            x=x,
            y=y,
            text=color_highlighted_words(node),
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=12),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=450,
            height=65
        )

    # One line per edge, labelled slightly above its midpoint.
    for i, (start, end) in enumerate(edges):
        x0, y0 = positions[start]
        x1, y1 = positions[end]
        fig2.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

        # Generic label once the predefined ones run out.
        text = edge_texts[i] if i < len(edge_texts) else f"Edge {i+1}"
        fig2.add_annotation(
            x=(x0 + x1) / 2,
            y=(y0 + y1) / 2 + 0.8,
            text=text,
            showarrow=False,
            font=dict(size=12),
            align="center"
        )

    fig2.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1435,
        height=1000
    )

    return fig2