import nltk

# The stopword corpus is fetched before the project modules are imported.
# NOTE(review): presumably one of them reads the corpus at import time --
# confirm before reordering these statements.
nltk.download('stopwords')

import plotly.graph_objs as go
from transformers import pipeline
import random
import gradio as gr
from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
from lcs import find_common_subsequences, find_common_gram_positions
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, high_entropy_words
from sampling_methods import sample_word
from detectability import SentenceDetectabilityCalculator
from distortion import SentenceDistortionCalculator
from euclidean_distance import SentenceEuclideanDistanceCalculator
from threeD_plot import gen_three_D_plot

# Threshold handed to analyze_entailment when splitting paraphrases into
# selected / discarded sentences.
ENTAILMENT_THRESHOLD = 0.7

# Every paraphrase is masked three ways (see the loop in model()).
MASKS_PER_SENTENCE = 3

# Every masked sentence is re-sampled once per technique.
SAMPLING_TECHNIQUES = ('inverse_transform', 'exponential_minimum', 'temperature', 'greedy')
SAMPLES_PER_SENTENCE = MASKS_PER_SENTENCE * len(SAMPLING_TECHNIQUES)

# Re-paraphrased sentences are rendered in HTML blocks of this many sentences.
REPARAPHRASE_BATCH_SIZE = 10

# Fixed number of UI slots. Gradio requires the handler to return exactly one
# value per output component, so model() fits its results to these sizes.
NUM_SENTENCE_TABS = 10
NUM_REPARAPHRASE_TABS = 120

HIGHLIGHT_COLORS = ["red", "blue", "brown", "green"]


def _fit_to_slots(items, slot_count, filler=None):
    """Return ``items`` as a list of exactly ``slot_count`` entries.

    Extra entries are dropped and missing ones are filled with ``filler`` so
    that the result always lines up with a fixed group of UI components.
    """
    fitted = list(items)[:slot_count]
    fitted.extend([filler] * (slot_count - len(fitted)))
    return fitted


# Function for the Gradio interface
def model(prompt):
    """Run the watermarking pipeline for ``prompt`` and build the UI outputs.

    Steps, in order: paraphrase the prompt, split the paraphrases with
    ``analyze_entailment``, find the subsequences shared with the prompt,
    mask every paraphrase three ways, re-sample every masked sentence with
    each technique in ``SAMPLING_TECHNIQUES``, re-paraphrase the sampled
    sentences, and compute distortion / detectability / Euclidean-distance
    metrics for the 3D plot.

    Args:
        prompt: The text entered in the "User Prompt" textbox.

    Returns:
        A list with one value per output component, in this order: the three
        highlight HTML fragments, ``NUM_SENTENCE_TABS`` masking trees,
        ``NUM_SENTENCE_TABS`` sampling trees, ``NUM_REPARAPHRASE_TABS`` HTML
        blocks of re-paraphrased sentences, and the 3D plot. Groups that
        produce fewer values than they have slots are padded with ``None``.
    """
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    _analyzed, selected_sentences, discarded_sentences = analyze_entailment(
        user_prompt, paraphrased_sentences, ENTAILMENT_THRESHOLD
    )

    common_grams = find_common_subsequences(user_prompt, selected_sentences)
    subsequences = [subseq for _, subseq in common_grams]
    # NOTE(review): the positions are not consumed anywhere below; the call is
    # kept only in case find_common_gram_positions has effects callers rely on.
    common_grams_position = find_common_gram_positions(selected_sentences, subsequences)

    # Mask each paraphrase three ways. Each masking function is called exactly
    # once per sentence with its arguments passed positionally; the number of
    # calls here must stay equal to MASKS_PER_SENTENCE.
    masked_outputs = []
    for sentence in paraphrased_sentences:
        masked_outputs.append(mask_non_stopword(sentence))
        masked_outputs.append(mask_non_stopword(sentence, True))
        masked_outputs.append(high_entropy_words(sentence, common_grams))

    # Unpack masked outputs into separate sequences.
    if masked_outputs:
        masked_sentences, masked_words, masked_logits = zip(*masked_outputs)
    else:
        masked_sentences, masked_words, masked_logits = [], [], []

    sampled_sentences = [
        sample_word(masked_sent, words, logits, sampling_technique=technique, temperature=1.0)
        for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits)
        for technique in SAMPLING_TECHNIQUES
    ]

    highlight_info = [(word, random.choice(HIGHLIGHT_COLORS)) for _, word in common_grams]

    highlighted_user_prompt = highlight_common_words(
        common_grams, [user_prompt], "Non-melting Points in the User Prompt"
    )
    highlighted_accepted_sentences = highlight_common_words_dict(
        common_grams, selected_sentences, "Paraphrased Sentences"
    )
    highlighted_discarded_sentences = highlight_common_words_dict(
        common_grams, discarded_sentences, "Discarded Sentences"
    )

    trees1, trees2 = [], []
    for i, sentence in enumerate(paraphrased_sentences):
        # Slices pick out the masked / sampled sentences derived from this paraphrase.
        next_masked_sentences = masked_sentences[i * MASKS_PER_SENTENCE:(i + 1) * MASKS_PER_SENTENCE]
        next_sampled_sentences = sampled_sentences[i * SAMPLES_PER_SENTENCE:(i + 1) * SAMPLES_PER_SENTENCE]

        trees1.append(generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams))
        trees2.append(generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams))

    reparaphrased_sentences = generate_paraphrase(sampled_sentences)

    # Process the sentences in batches; a trailing partial batch is skipped,
    # as in the original implementation.
    reparaphrased_sentences_list = []
    for start in range(0, len(reparaphrased_sentences), REPARAPHRASE_BATCH_SIZE):
        batch = reparaphrased_sentences[start:start + REPARAPHRASE_BATCH_SIZE]
        if len(batch) == REPARAPHRASE_BATCH_SIZE:
            reparaphrased_sentences_list.append(reparaphrased_sentences_html(batch))

    # Calculate metrics
    distortion_calculator = SentenceDistortionCalculator(user_prompt, reparaphrased_sentences)
    distortion_calculator.calculate_all_metrics()
    distortion_calculator.normalize_metrics()
    distortion_list = list(distortion_calculator.get_combined_distortions().values())

    detectability_calculator = SentenceDetectabilityCalculator(user_prompt, reparaphrased_sentences)
    detectability_calculator.calculate_all_metrics()
    detectability_calculator.normalize_metrics()
    detectability_list = list(detectability_calculator.get_combined_detectabilities().values())

    euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(user_prompt, reparaphrased_sentences)
    euclidean_dist_calculator.calculate_all_metrics()
    euclidean_dist_calculator.normalize_metrics()
    euclidean_dist_list = list(euclidean_dist_calculator.get_normalized_metrics().values())

    three_D_plot = gen_three_D_plot(detectability_list, distortion_list, euclidean_dist_list)

    # The click handler is bound to a fixed number of components, so every
    # variable-length group is fitted to its slot count before returning.
    return (
        [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences]
        + _fit_to_slots(trees1, NUM_SENTENCE_TABS)
        + _fit_to_slots(trees2, NUM_SENTENCE_TABS)
        + _fit_to_slots(reparaphrased_sentences_list, NUM_REPARAPHRASE_TABS)
        + [three_D_plot]
    )


def _build_tabbed_components(component_factory, count):
    """Create ``count`` "Sentence N" tabs, each holding one new component.

    The component is instantiated inside its ``gr.TabItem`` context so that it
    is laid out in that tab. Returns the components in tab order.
    """
    components = []
    with gr.Tabs():
        for index in range(count):
            with gr.TabItem(f"Sentence {index + 1}"):
                components.append(component_factory())
    return components


# Gradio Interface
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# **AIISC Watermarking Model**")

    with gr.Row():
        user_input = gr.Textbox(label="User Prompt")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    with gr.Row():
        highlighted_user_prompt = gr.HTML()

    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Sentences"):
                highlighted_discarded_sentences = gr.HTML()

    with gr.Row():
        gr.Markdown("### Where to Watermark?")  # Label for masked sentences trees

    with gr.Row():
        tree1_tabs = _build_tabbed_components(gr.Plot, NUM_SENTENCE_TABS)

    with gr.Row():
        gr.Markdown("### How to Watermark?")  # Label for sampled sentences trees

    with gr.Row():
        tree2_tabs = _build_tabbed_components(gr.Plot, NUM_SENTENCE_TABS)

    with gr.Row():
        gr.Markdown("### Re-paraphrased Sentences")  # Label for re-paraphrased sentences

    with gr.Row():
        reparaphrased_sentences_tabs = _build_tabbed_components(gr.HTML, NUM_REPARAPHRASE_TABS)

    with gr.Row():
        gr.Markdown("### 3D Plot for Sweet Spot")

    with gr.Row():
        three_D_plot = gr.Plot()

    # Order must match the list returned by model().
    output_components = (
        [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences]
        + tree1_tabs
        + tree2_tabs
        + reparaphrased_sentences_tabs
        + [three_D_plot]
    )

    submit_button.click(model, inputs=user_input, outputs=output_components)

    clear_button.click(lambda: "", inputs=None, outputs=user_input)
    # One value per output component is required; None clears both the HTML
    # and the Plot components.
    clear_button.click(lambda: [None] * len(output_components), inputs=None, outputs=output_components)

demo.launch(share=True)