import plotly.graph_objs as go
import textwrap
import re
from collections import defaultdict
from paraphraser import generate_paraphrase
from masking_methods import mask, mask_non_stopword


def _wrap_label(text, width=30):
    """Return *text* wrapped to at most *width* columns per line.

    Lines are joined with ``<br>`` because Plotly annotation text is rendered
    as HTML-like markup; a plain ``\\n`` does not produce a visible break.
    """
    return '<br>'.join(textwrap.wrap(text, width=width))


def _build_levels_and_edges(n_paraphrases, n_masked):
    """Return ``(levels, edges)`` for the three-tier node list.

    Node indices follow the order used by ``generate_plot``: index 0 is the
    original sentence (level 0), indices ``1..n_paraphrases`` are the
    paraphrases (level 1), and the following ``n_masked`` indices are the
    masked variants (level 2).

    ``levels`` maps node index -> level, in index order.  ``edges`` is a list
    of ``(parent, child)`` index pairs: the root links to every level-1 node,
    and only the first level-1 node (the one that was masked) links to the
    level-2 nodes.
    """
    first_masked = n_paraphrases + 1
    end = first_masked + n_masked

    levels = {0: 0}
    for i in range(1, first_masked):
        levels[i] = 1
    for i in range(first_masked, end):
        levels[i] = 2

    edges = [(0, i) for i in range(1, first_masked)]
    if n_paraphrases:
        edges.extend((1, i) for i in range(first_masked, end))
    return levels, edges


def _compute_positions(levels, y_gap=4):
    """Return ``{node_index: (x, y)}`` with each level centred on ``x = 0``.

    Nodes within a level are spaced one unit apart in index order; levels are
    stacked downward, ``y_gap`` units apart.
    """
    level_widths = defaultdict(int)
    for level in levels.values():
        level_widths[level] += 1

    # Leftmost x for each level so that the row is centred on zero.
    next_x = {level: -(width - 1) / 2 for level, width in level_widths.items()}

    positions = {}
    for node, level in levels.items():
        positions[node] = (next_x[level], -level * y_gap)
        next_x[level] += 1
    return positions


def generate_plot(original_sentence):
    """Build a Plotly tree of a sentence, its paraphrases, and masked variants.

    The original sentence is the root.  Every paraphrase returned by
    ``generate_paraphrase`` hangs below it.  The first paraphrase is passed
    through ``mask_non_stopword`` and then ``mask``, and each resulting
    variant hangs below that first paraphrase.

    Args:
        original_sentence: The sentence to paraphrase and visualise.

    Returns:
        A ``plotly.graph_objs.Figure`` with one marker trace and one
        annotation per node, plus one line trace per edge.

    NOTE(review): assumes ``generate_paraphrase`` returns an indexable
    sequence of strings and ``mask`` returns an iterable of strings —
    confirm against those modules.
    """
    paraphrased_sentences = generate_paraphrase(original_sentence)
    if paraphrased_sentences:
        masked_sentence = mask_non_stopword(paraphrased_sentences[0])
        masked_versions = mask(masked_sentence)
    else:
        # No paraphrase to mask: plot just the root instead of failing on [0].
        masked_versions = []

    nodes = [original_sentence]
    nodes.extend(paraphrased_sentences)
    nodes.extend(masked_versions)

    n_paraphrases = len(paraphrased_sentences)
    n_masked = len(nodes) - 1 - n_paraphrases
    levels, edges = _build_levels_and_edges(n_paraphrases, n_masked)
    positions = _compute_positions(levels)

    wrapped_nodes = [_wrap_label(node) for node in nodes]

    fig = go.Figure()

    # Draw edges first so the node markers are rendered on top of the lines.
    for parent, child in edges:
        x0, y0 = positions[parent]
        x1, y1 = positions[child]
        fig.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=2),
        ))

    for i, label in enumerate(wrapped_nodes):
        x, y = positions[i]
        fig.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none',
        ))
        fig.add_annotation(
            x=x,
            y=y,
            text=label,
            showarrow=False,
            yshift=20,  # lift the label box above its marker to avoid overlap
            align="center",
            font=dict(size=10),
            bordercolor='black',
            borderwidth=1,
            borderpad=4,
            bgcolor='white',
            width=200,
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1470,
        height=800,  # tall enough for three stacked, wrapped label rows
    )
    return fig