Upload folder using huggingface_hub
- app.py +93 -19
- entailment.py +1 -1
- highlighter.py +33 -42
- lcs.py +3 -3
- masking_methods.py +84 -12
- paraphraser.py +1 -1
- sampling_methods.py +31 -139
- tree.py +90 -47
app.py
CHANGED
@@ -6,7 +6,6 @@ import plotly.graph_objs as go
 import textwrap
 from transformers import pipeline
 import re
-import time
 import requests
 from PIL import Image
 import itertools
@@ -20,10 +19,7 @@ import pandas as pd
 from pprint import pprint
 from tenacity import retry
 from tqdm import tqdm
-import scipy.stats
-import torch
 from transformers import GPT2LMHeadModel
-import seaborn as sns
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM
 import random
 from nltk.corpus import stopwords
@@ -31,22 +27,92 @@ from termcolor import colored
 from nltk.translate.bleu_score import sentence_bleu
 from transformers import BertTokenizer, BertModel
 import gradio as gr
-from tree import
+from tree import generate_subplot
 from paraphraser import generate_paraphrase
 from lcs import find_common_subsequences
 from highlighter import highlight_common_words, highlight_common_words_dict
 from entailment import analyze_entailment
+from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
+from sampling_methods import sample_word

 # Function for the Gradio interface
 def model(prompt):
-    paraphrased_sentences = generate_paraphrase(
-    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(
+    user_prompt = prompt
+    paraphrased_sentences = generate_paraphrase(user_prompt)
+    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
+    length_accepted_sentences = len(selected_sentences)
+    common_grams = find_common_subsequences(user_prompt, selected_sentences)
+
+    masked_sentences = []
+    masked_words = []
+    masked_logits = []
+    selected_sentences_list = list(selected_sentences.keys())
+
+    for sentence in selected_sentences_list:
+        # Mask non-stopword
+        masked_sent, logits, words = mask_non_stopword(sentence)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+        # Mask non-stopword pseudorandom
+        masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+        # High entropy words
+        masked_sent, logits, words = high_entropy_words(sentence, common_grams)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+    sampled_sentences = []
+    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='inverse_transform', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='exponential_minimum', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))
+
+    # Predefined set of colors that are visible on a white background, excluding black
+    colors = ["red", "blue", "brown", "green"]
+
+    # Function to generate a color from the predefined set
+    def select_color():
+        return random.choice(colors)
+
+    # Create highlight_info with selected colors
+    highlight_info = [(word, select_color()) for _, word in common_grams]
+
+    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "User Prompt (Highlighted and Numbered)")
+    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
+    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
+
+    # Initialize an empty list to hold the trees
+    trees = []
+
+    # Initialize the indices for masked and sampled sentences
+    masked_index = 0
+    sampled_index = 0
+
+    for i, sentence in enumerate(selected_sentences):
+        # Generate the sublists of masked and sampled sentences based on the current indices
+        next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
+        next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]
+
+        # Create the tree for the current sentence
+        tree = generate_subplot(sentence, next_masked_sentences, next_sampled_sentences, highlight_info)
+        trees.append(tree)
+
+        # Update the indices for the next iteration
+        masked_index += 3
+        sampled_index += 12
+
+    # Return all the outputs together
+    return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees


 with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
@@ -63,15 +129,23 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
     highlighted_user_prompt = gr.HTML()

     with gr.Row():
+        with gr.Tabs():
+            with gr.TabItem("Paraphrased Sentences"):
+                highlighted_accepted_sentences = gr.HTML()
+            with gr.TabItem("Discarded Sentences"):
+                highlighted_discarded_sentences = gr.HTML()

     with gr.Row():
+        with gr.Tabs():
+            tree_tabs = []
+            for i in range(3):  # Adjust this range according to the number of trees
+                with gr.TabItem(f"Tree {i+1}"):
+                    tree = gr.Plot()
+                    tree_tabs.append(tree)

-    submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt,
+    submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree_tabs)
     clear_button.click(lambda: "", inputs=None, outputs=user_input)
-    clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt,
+    clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree_tabs)

     # Launch the demo
-demo.launch(share=True)
+demo.launch(share=True)
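For orientation, the bookkeeping in the new model() assumes each selected sentence yields 3 masked variants (one per masking scheme) and each masked variant yields 4 sampled fills, so the per-sentence slices are 3 masked and 12 sampled sentences. A minimal sketch of that index arithmetic, using placeholder strings rather than real model outputs:

# Sketch of the slicing in model(): 3 masked variants and 12 sampled fills
# per selected sentence (placeholder strings, not real outputs).
masked_sentences = [f"masked {i}" for i in range(2 * 3)]     # 2 sentences x 3 schemes
sampled_sentences = [f"sampled {i}" for i in range(2 * 12)]  # 2 sentences x 3 schemes x 4 samplers

masked_index, sampled_index = 0, 0
for sentence_id in range(2):
    next_masked = masked_sentences[masked_index:masked_index + 3]
    next_sampled = sampled_sentences[sampled_index:sampled_index + 12]
    print(sentence_id, next_masked, len(next_sampled))
    masked_index += 3
    sampled_index += 12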
entailment.py
CHANGED
@@ -28,4 +28,4 @@ def analyze_entailment(original_sentence, paraphrased_sentences, threshold):

     return all_sentences, selected_sentences, discarded_sentences

-print(analyze_entailment("I love you", ["You're being loved by me"], 0.7))
+# print(analyze_entailment("I love you", ["You're being loved by me"], 0.7))
highlighter.py
CHANGED
@@ -39,57 +39,48 @@ def highlight_common_words(common_words, sentences, title):
     '''


+
 import re

-def highlight_common_words_dict(common_words,
+def highlight_common_words_dict(common_words, sentences, title):
     color_map = {}
     color_index = 0
     highlighted_html = []

-    highlighted_sentences = [f'<h4 style="color: #374151; margin-bottom: 5px;">{section_title}</h4>']
-
-    for
+    for idx, (sentence, score) in enumerate(sentences.items(), start=1):
+        sentence_with_idx = f"{idx}. {sentence}"
+        highlighted_sentence = sentence_with_idx
+
+        for index, word in common_words:
+            if word not in color_map:
+                color_map[word] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
+                color_index += 1
+            escaped_word = re.escape(word)
+            pattern = rf'\b{escaped_word}\b'
+            highlighted_sentence = re.sub(
+                pattern,
+                lambda m, idx=index, color=color_map[word]: (
+                    f'<span style="background-color: {color}; font-weight: bold;'
+                    f' padding: 1px 2px; border-radius: 2px; position: relative;">'
+                    f'<span style="background-color: black; color: white; border-radius: 50%;'
+                    f' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{idx}</span>'
+                    f'{m.group(0)}'
+                    f'</span>'
+                ),
+                highlighted_sentence,
+                flags=re.IGNORECASE
             )
-        highlighted_sentences.append(
+        highlighted_html.append(
             f'<div style="margin-bottom: 5px;">'
             f'{highlighted_sentence}'
-            f'<div style="display: inline-block; margin-left: 5px; border: 1px solid #ddd; padding: 3px 5px; border-radius: 3px; background-color: white; font-size: 0.9em;">'
+            f'<div style="display: inline-block; margin-left: 5px; padding: 3px 5px; border-radius: 3px; background-color: white; font-size: 0.9em;">'
             f'Entailment Score: {score}</div></div>'
         )

-    final_html = "<br>".join(
+    final_html = "<br>".join(highlighted_html)
     return f'''
-    <div style="
-    <h3 style="margin-top: 0; font-size: 1em; color: #111827;
+    <div style="background-color: #ffffff; color: #374151;">
+    <h3 style="margin-top: 0; font-size: 1em; color: #111827;">{title}</h3>
     <div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
     </div>
     '''
lcs.py
CHANGED
@@ -40,7 +40,7 @@ def find_common_subsequences(sentence, str_list):
     return indexed_common_grams

 # Example usage
-sentence = "Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."
-str_list = ["The founder of South Korean technology company Kakao, billionaire Kim Beom-su, was arrested on charges of stock fraud during a bidding war for one of North Korea's biggest K-pop companies.", "In a bidding war for one of South Korea's largest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "During a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "Kim Beom-su, the founder of South Korean technology giant Kakao's billionaire investor status, was arrested on charges of stock fraud during a bidding war for one of North Korea'S top K-pop agencies.", "A bidding war over one of South Korea's biggest K-pop agencies led to the arrest and apprehension charges of Kim Beom-Su, the billionaire who owns the technology giant Kakao.", "The billionaire who owns South Korean technology giant Kakao, Kim Beom-Su, was taken into custody for allegedly engaging in stock trading during a bidding war for one of North Korea's biggest K-pop media groups.", "Accused of stockpiling during a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-Su, the founder and owner of technology firm known as Kakao, was arrested on charges of manipulating stocks.", 'Kakao, the South Korean technology giant, was involved in a bidding war with Kim Beon-su, its founder, who was arrested on charges of manipulating stocks.', "South Korea's Kakao corporation'entrepreneur husband, Kim Beom-su (pictured), was arrested on suspicion of stock fraud during a bidding war for one of the country'S top K-pop companies.", 'Kim Beom-su, the billionaire who own a South Korean technology company called Kakaof, was arrested on charges of manipulating stocks in an ongoing bidding war over one million shares.']
-print(find_common_subsequences(sentence, str_list))
+# sentence = "Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."
+# str_list = ["The founder of South Korean technology company Kakao, billionaire Kim Beom-su, was arrested on charges of stock fraud during a bidding war for one of North Korea's biggest K-pop companies.", "In a bidding war for one of South Korea's largest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "During a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "Kim Beom-su, the founder of South Korean technology giant Kakao's billionaire investor status, was arrested on charges of stock fraud during a bidding war for one of North Korea'S top K-pop agencies.", "A bidding war over one of South Korea's biggest K-pop agencies led to the arrest and apprehension charges of Kim Beom-Su, the billionaire who owns the technology giant Kakao.", "The billionaire who owns South Korean technology giant Kakao, Kim Beom-Su, was taken into custody for allegedly engaging in stock trading during a bidding war for one of North Korea's biggest K-pop media groups.", "Accused of stockpiling during a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-Su, the founder and owner of technology firm known as Kakao, was arrested on charges of manipulating stocks.", 'Kakao, the South Korean technology giant, was involved in a bidding war with Kim Beon-su, its founder, who was arrested on charges of manipulating stocks.', "South Korea's Kakao corporation'entrepreneur husband, Kim Beom-su (pictured), was arrested on suspicion of stock fraud during a bidding war for one of the country'S top K-pop companies.", 'Kim Beom-su, the billionaire who own a South Korean technology company called Kakaof, was arrested on charges of manipulating stocks in an ongoing bidding war over one million shares.']
+# print(find_common_subsequences(sentence, str_list))
masking_methods.py
CHANGED
@@ -1,3 +1,66 @@
+# from transformers import AutoTokenizer, AutoModelForMaskedLM
+# from transformers import pipeline
+# import random
+# from nltk.corpus import stopwords
+# import math
+
+# # Masking Model
+# def mask_non_stopword(sentence):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+#     non_stop_words = [word for word in words if word.lower() not in stop_words]
+#     if not non_stop_words:
+#         return sentence
+#     word_to_mask = random.choice(non_stop_words)
+#     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+#     return masked_sentence
+
+# def mask_non_stopword_pseudorandom(sentence):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+#     non_stop_words = [word for word in words if word.lower() not in stop_words]
+#     if not non_stop_words:
+#         return sentence
+#     random.seed(10)
+#     word_to_mask = random.choice(non_stop_words)
+#     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+#     return masked_sentence
+
+# def high_entropy_words(sentence, non_melting_points):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+
+#     non_melting_words = set()
+#     for _, point in non_melting_points:
+#         non_melting_words.update(point.lower().split())
+
+#     candidate_words = [word for word in words if word.lower() not in stop_words and word.lower() not in non_melting_words]
+
+#     if not candidate_words:
+#         return sentence
+
+#     max_entropy = -float('inf')
+#     max_entropy_word = None
+
+#     for word in candidate_words:
+#         masked_sentence = sentence.replace(word, '[MASK]', 1)
+#         predictions = fill_mask(masked_sentence)
+
+#         # Calculate entropy based on top 5 predictions
+#         entropy = -sum(pred['score'] * math.log(pred['score']) for pred in predictions[:5])
+
+#         if entropy > max_entropy:
+#             max_entropy = entropy
+#             max_entropy_word = word
+
+#     return sentence.replace(max_entropy_word, '[MASK]', 1)
+
+
+# # Load tokenizer and model for masked language model
+# tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
+# model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
+# fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
+
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 from transformers import pipeline
 import random
@@ -10,21 +73,27 @@ def mask_non_stopword(sentence):
     words = sentence.split()
     non_stop_words = [word for word in words if word.lower() not in stop_words]
     if not non_stop_words:
-        return sentence
+        return sentence, None, None
     word_to_mask = random.choice(non_stop_words)
     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+    predictions = fill_mask(masked_sentence)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits

 def mask_non_stopword_pseudorandom(sentence):
     stop_words = set(stopwords.words('english'))
     words = sentence.split()
     non_stop_words = [word for word in words if word.lower() not in stop_words]
     if not non_stop_words:
-        return sentence
+        return sentence, None, None
     random.seed(10)
     word_to_mask = random.choice(non_stop_words)
     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+    predictions = fill_mask(masked_sentence)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits

 def high_entropy_words(sentence, non_melting_points):
     stop_words = set(stopwords.words('english'))
@@ -37,10 +106,11 @@ def high_entropy_words(sentence, non_melting_points):
     candidate_words = [word for word in words if word.lower() not in stop_words and word.lower() not in non_melting_words]

     if not candidate_words:
-        return sentence
+        return sentence, None, None

     max_entropy = -float('inf')
     max_entropy_word = None
+    max_logits = None

     for word in candidate_words:
         masked_sentence = sentence.replace(word, '[MASK]', 1)
@@ -52,17 +122,19 @@
         if entropy > max_entropy:
             max_entropy = entropy
             max_entropy_word = word
+            max_logits = [pred['score'] for pred in predictions]

+    masked_sentence = sentence.replace(max_entropy_word, '[MASK]', 1)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits

 # Load tokenizer and model for masked language model
 tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
 model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
 fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

+non_melting_points = [(1, 'Jewish'), (2, 'messages'), (3, 'stab')]
+a, b, c = high_entropy_words("A former Cornell University student was sentenced to 21 months in prison on Monday after admitting that he had posted a series of online messages last fall in which he threatened to stab, rape and behead Jewish people", non_melting_points)
+print(f"logits type: {type(b)}")
+print(f"logits content: {b}")
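Note on the new return contract: each masking function now returns (masked_sentence, words, logits), where `words` is built from pred['score'] and `logits` from pred['token_str']. app.py unpacks the tuple as masked_sent, logits, words, so sample_word ultimately receives the token strings as `words` and the scores as `logits`. A hypothetical walkthrough of that hand-off (the masked sentence and prediction values below are invented, not real fill-mask output):

# Hypothetical fill-mask output for one masked sentence (values invented).
masked_sentence = "The student was [MASK] to 21 months in prison."
predictions = [
    {"token_str": "sentenced", "score": 0.72},
    {"token_str": "jailed",    "score": 0.15},
    {"token_str": "condemned", "score": 0.08},
]

# What mask_non_stopword now returns:
words  = [pred["score"] for pred in predictions]      # [0.72, 0.15, 0.08]
logits = [pred["token_str"] for pred in predictions]  # ["sentenced", "jailed", "condemned"]

# app.py unpacks the returned tuple as (masked_sent, logits, words):
masked_sent, app_logits, app_words = masked_sentence, words, logits
# app_logits -> the scores, app_words -> the token strings, which is the
# order sample_word expects (words=tokens, logits=scores).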
paraphraser.py
CHANGED
@@ -28,4 +28,4 @@ def generate_paraphrase(question):
     res = paraphrase(question, para_tokenizer, para_model)
     return res

-print(generate_paraphrase("Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."))
+# print(generate_paraphrase("Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."))
sampling_methods.py
CHANGED
@@ -1,145 +1,33 @@
-        sentence = sentence.replace(word_to_mark, colored(word_to_mark, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Inverse Transform Sampling
-def inverse_transform_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            probabilities = [1 / len(words_to_replace)] * len(words_to_replace)
-            chosen_word = random.choices(words_to_replace, weights=probabilities)[0]
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'magenta'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Contextual Sampling
-def contextual_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            context = " ".join([word for word in sentence.split() if word not in common_words])
-            chosen_word = random.choice(words_to_replace)
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Exponential Minimum Sampling
-def exponential_minimum_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            num_words = len(words_to_replace)
-            probabilities = [2 ** (-i) for i in range(num_words)]
-            chosen_word = random.choices(words_to_replace, weights=probabilities)[0]
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-#---------------------------------------------------------------------------
-# aryans implementation please refactor it as you see fit
+# import torch
+# import random
+
+# def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
+#     if sampling_technique == 'inverse_transform':
+#         probs = torch.softmax(torch.tensor(logits), dim=-1)
+#         cumulative_probs = torch.cumsum(probs, dim=-1)
+#         random_prob = random.random()
+#         sampled_index = torch.where(cumulative_probs >= random_prob)[0][0]
+#     elif sampling_technique == 'exponential_minimum':
+#         probs = torch.softmax(torch.tensor(logits), dim=-1)
+#         exp_probs = torch.exp(-torch.log(probs))
+#         random_probs = torch.rand_like(exp_probs)
+#         sampled_index = torch.argmax(random_probs * exp_probs)
+#     elif sampling_technique == 'temperature':
+#         scaled_logits = torch.tensor(logits) / temperature
+#         probs = torch.softmax(scaled_logits, dim=-1)
+#         sampled_index = torch.multinomial(probs, 1).item()
+#     elif sampling_technique == 'greedy':
+#         sampled_index = torch.argmax(torch.tensor(logits)).item()
+#     else:
+#         raise ValueError("Invalid sampling technique. Choose 'inverse_transform', 'exponential_minimum', 'temperature', or 'greedy'.")
+
+#     sampled_word = words[sampled_index]
+#     return sampled_word

 import torch
 import random

-def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
+def sample_word(sentence, words, logits, sampling_technique='inverse_transform', temperature=1.0):
     if sampling_technique == 'inverse_transform':
         probs = torch.softmax(torch.tensor(logits), dim=-1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
@@ -160,4 +48,8 @@ def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
         raise ValueError("Invalid sampling technique. Choose 'inverse_transform', 'exponential_minimum', 'temperature', or 'greedy'.")

     sampled_word = words[sampled_index]
+
+    # Replace [MASK] with the sampled word
+    filled_sentence = sentence.replace('[MASK]', sampled_word)
+
+    return filled_sentence
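A minimal usage sketch of the new sample_word signature; the masked sentence, candidate tokens, and scores below are invented for illustration rather than taken from the fill-mask model:

from sampling_methods import sample_word

masked_sentence = "The billionaire was [MASK] on Monday."
candidate_tokens = ["arrested", "detained", "questioned", "released"]  # hypothetical tokens
candidate_scores = [0.55, 0.25, 0.15, 0.05]                            # hypothetical scores

# 'greedy' picks the highest-scoring token and fills the [MASK] slot.
filled = sample_word(masked_sentence, candidate_tokens, candidate_scores,
                     sampling_technique='greedy', temperature=1.0)
print(filled)  # The billionaire was arrested on Monday.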
tree.py
CHANGED
@@ -1,29 +1,31 @@
-import plotly.
+import plotly.graph_objects as go
 import textwrap
 import re
 from collections import defaultdict
+
+def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, highlight_info):
+    # Combine nodes into one list with appropriate labels
+    nodes = [paraphrased_sentence] + scheme_sentences + sampled_sentence
+    nodes[0] += ' L0'  # Paraphrased sentence is level 0
+    para_len = len(scheme_sentences)
+    for i in range(1, para_len + 1):
+        nodes[i] += ' L1'  # Scheme sentences are level 1
+    for i in range(para_len + 1, len(nodes)):
+        nodes[i] += ' L2'  # Sampled sentences are level 2
+
+    # Define the highlight_words function
+    def highlight_words(sentence, color_map):
+        for word, color in color_map.items():
+            sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
+        return sentence
+
+    # Clean and wrap nodes, and highlight specified words globally
     cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
+    global_color_map = dict(highlight_info)
+    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
+    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in highlighted_nodes]
+
+    # Function to determine tree levels and create edges dynamically
     def get_levels_and_edges(nodes):
         levels = {}
         edges = []
@@ -37,58 +39,99 @@ def generate_plot(original_sentence, selected_sentences):
             if level == 1:
                 edges.append((root_node, i))

+        # Add edges from each L1 node to their corresponding L2 nodes
+        l1_indices = [i for i, level in levels.items() if level == 1]
+        l2_indices = [i for i, level in levels.items() if level == 2]
+
+        for i, l1_node in enumerate(l1_indices):
+            l2_start = i * 4
+            for j in range(4):
+                l2_index = l2_start + j
+                if l2_index < len(l2_indices):
+                    edges.append((l1_node, l2_indices[l2_index]))
+
+        # Add edges from each L2 node to their corresponding L3 nodes
+        l2_indices = [i for i, level in levels.items() if level == 2]
+        l3_indices = [i for i, level in levels.items() if level == 3]
+
+        l2_to_l3_map = {l2_node: [] for l2_node in l2_indices}
+
+        # Map L3 nodes to L2 nodes
+        for l3_node in l3_indices:
+            l2_node = l3_node % len(l2_indices)
+            l2_to_l3_map[l2_indices[l2_node]].append(l3_node)
+
+        for l2_node, l3_nodes in l2_to_l3_map.items():
+            for l3_node in l3_nodes:
+                edges.append((l2_node, l3_node))

         return levels, edges

     # Get levels and dynamic edges
     levels, edges = get_levels_and_edges(nodes)
-    max_level = max(levels.values())
+    max_level = max(levels.values(), default=0)

     # Calculate positions
     positions = {}
+    level_heights = defaultdict(int)
     for node, level in levels.items():
+        level_heights[level] += 1

+    y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
+    x_gap = 2
+    l1_y_gap = 10
+    l2_y_gap = 6

     for node, level in levels.items():
+        if level == 1:
+            positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
+        elif level == 2:
+            positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
+        else:
+            positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
+        y_offsets[level] += 1
+
+    # Function to highlight words in a wrapped node string
+    def color_highlighted_words(node, color_map):
+        parts = re.split(r'(\{\{.*?\}\})', node)
+        colored_parts = []
+        for part in parts:
+            match = re.match(r'\{\{(.*?)\}\}', part)
+            if match:
+                word = match.group(1)
+                color = color_map.get(word, 'black')
+                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
+            else:
+                colored_parts.append(part)
+        return ''.join(colored_parts)

     # Create figure
     fig = go.Figure()

     # Add nodes to the figure
     for i, node in enumerate(wrapped_nodes):
+        colored_node = color_highlighted_words(node, global_color_map)
         x, y = positions[i]
         fig.add_trace(go.Scatter(
-            x=[x],
+            x=[-x],  # Reflect the x coordinate
             y=[y],
             mode='markers',
             marker=dict(size=10, color='blue'),
             hoverinfo='none'
         ))
         fig.add_annotation(
-            x
+            x=-x,  # Reflect the x coordinate
             y=y,
-            text=
+            text=colored_node,
             showarrow=False,
+            xshift=15,
             align="center",
-            font=dict(size=
+            font=dict(size=8),
             bordercolor='black',
             borderwidth=1,
-            borderpad=
+            borderpad=2,
             bgcolor='white',
-            width=
+            width=150
         )

     # Add edges to the figure
@@ -96,19 +139,19 @@ def generate_plot(original_sentence, selected_sentences):
         x0, y0 = positions[edge[0]]
         x1, y1 = positions[edge[1]]
         fig.add_trace(go.Scatter(
-            x=[x0, x1],
+            x=[-x0, -x1],  # Reflect the x coordinates
             y=[y0, y1],
             mode='lines',
-            line=dict(color='black', width=
+            line=dict(color='black', width=1)
         ))

     fig.update_layout(
         showlegend=False,
-        margin=dict(t=
+        margin=dict(t=20, b=20, l=20, r=20),
         xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
         yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
-        width=
-        height=
+        width=1200,  # Adjusted width to accommodate more levels
+        height=1000  # Adjusted height to accommodate more levels
     )

     return fig
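A small sketch of how app.py calls the new generate_subplot: one paraphrased sentence, its three masked variants, the twelve sampled fills, and the (word, color) highlight pairs. The strings below are placeholders, not real pipeline output:

from tree import generate_subplot

paraphrased = "Kim Beom-su was arrested over stock manipulation."
masked = [f"masked variant {i}" for i in range(3)]        # one per masking scheme
sampled = [f"sampled variant {i}" for i in range(12)]     # 4 sampling techniques x 3 masked variants
highlight_info = [("arrested", "red"), ("stock", "blue")] # (word, color) pairs built in app.py

fig = generate_subplot(paraphrased, masked, sampled, highlight_info)
fig.show()  # app.py instead returns the figure to a gr.Plot component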