import nltk
nltk.download('stopwords')
nltk.download('punkt_tab')
# from transformers import AutoTokenizer
# from transformers import AutoModelForSeq2SeqLM
import plotly.graph_objs as go
from transformers import pipeline
import random
import gradio as gr
from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
from lcs import find_common_subsequences, find_common_gram_positions
from highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from sampling_methods import sample_word
from detectability import SentenceDetectabilityCalculator
from distortion import SentenceDistortionCalculator
from euclidean_distance import SentenceEuclideanDistanceCalculator
from threeD_plot import gen_three_D_plot


# Function for the Gradio interface
def model(prompt):
    """Run the full pipeline for ``prompt`` and return the values for the Gradio outputs.

    The function paraphrases the prompt, splits the paraphrases into selected
    and discarded sets with ``analyze_entailment``, finds the subsequences
    common to the prompt and the selected sentences, masks every paraphrase
    with three masking methods, samples every masked sentence with four
    sampling techniques, builds two plots per paraphrase, re-paraphrases the
    sampled sentences, and finally computes distortion, detectability and
    euclidean-distance values that feed the 3D plot.

    Args:
        prompt: The user prompt taken from the input textbox.

    Returns:
        list: ``[highlighted prompt, highlighted selected sentences,
        highlighted discarded sentences]`` followed by the first set of tree
        plots, the second set of tree plots, the HTML blocks of
        re-paraphrased sentences (one per full batch of 10), and the 3D plot.
    """
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)

    # Run the entailment analysis once and reuse the result for the debug
    # print instead of invoking the analysis a second time.
    # 0.7 is passed through unchanged (presumably the entailment threshold).
    entailment_result = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
    print(entailment_result)
    _, selected_sentences, discarded_sentences = entailment_result

    common_grams = find_common_subsequences(user_prompt, selected_sentences)
    subsequences = [subseq for _, subseq in common_grams]
    common_grams_position = find_common_gram_positions(selected_sentences, subsequences)

    masked_sentences = []
    masked_words = []
    masked_logits = []

    # Each paraphrase contributes three masked variants, in this fixed order.
    masks_per_sentence = 3
    for sentence in paraphrased_sentences:
        mask_results = (
            mask_non_stopword(sentence),
            mask_non_stopword_pseudorandom(sentence),
            high_entropy_words(sentence, common_grams),
        )
        for masked_sent, logits, words in mask_results:
            masked_sentences.append(masked_sent)
            masked_words.append(words)
            masked_logits.append(logits)

    # Each masked sentence is sampled once per technique, in this fixed order.
    sampling_techniques = ('inverse_transform', 'exponential_minimum', 'temperature', 'greedy')
    sampled_sentences = [
        sample_word(masked_sent, words, logits, sampling_technique=technique, temperature=1.0)
        for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits)
        for technique in sampling_techniques
    ]
    samples_per_sentence = masks_per_sentence * len(sampling_techniques)

    colors = ["red", "blue", "brown", "green"]
    highlight_info = [(word, random.choice(colors)) for _, word in common_grams]

    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")

    trees1 = []
    trees2 = []
    for index, sentence in enumerate(paraphrased_sentences):
        # Slice out the masked / sampled variants that belong to this paraphrase.
        masked_start = index * masks_per_sentence
        sampled_start = index * samples_per_sentence
        next_masked_sentences = masked_sentences[masked_start:masked_start + masks_per_sentence]
        next_sampled_sentences = sampled_sentences[sampled_start:sampled_start + samples_per_sentence]

        trees1.append(generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams))
        trees2.append(generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams))

    reparaphrased_sentences = generate_paraphrase(sampled_sentences)

    # Render the re-paraphrased sentences in batches of 10; a trailing batch
    # with fewer than 10 sentences is skipped, as before.
    batch_size = 10
    reparaphrased_sentences_list = [
        reparaphrased_sentences_html(reparaphrased_sentences[start:start + batch_size])
        for start in range(0, len(reparaphrased_sentences) - batch_size + 1, batch_size)
    ]

    distortion_calculator = SentenceDistortionCalculator(user_prompt, reparaphrased_sentences)
    distortion_calculator.calculate_all_metrics()
    distortion_calculator.normalize_metrics()
    distortion_calculator.calculate_combined_distortion()
    distortion_list = list(distortion_calculator.get_combined_distortions().values())

    detectability_calculator = SentenceDetectabilityCalculator(user_prompt, reparaphrased_sentences)
    detectability_calculator.calculate_all_metrics()
    detectability_calculator.normalize_metrics()
    detectability_calculator.calculate_combined_detectability()
    detectability_list = list(detectability_calculator.get_combined_detectabilities().values())

    euclidean_dist_calculator = SentenceEuclideanDistanceCalculator(user_prompt, reparaphrased_sentences)
    euclidean_dist_calculator.calculate_all_metrics()
    euclidean_dist_calculator.normalize_metrics()
    # Use the euclidean calculator's own result; previously the detectability
    # values were read here by mistake, duplicating one axis of the plot.
    # NOTE(review): assumes get_normalized_metrics() returns a mapping like the
    # other calculators' getters -- confirm against euclidean_distance.py.
    euclidean_dist = euclidean_dist_calculator.get_normalized_metrics()
    euclidean_dist_list = list(euclidean_dist.values())

    three_D_plot = gen_three_D_plot(detectability_list, distortion_list, euclidean_dist_list)

    return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees1 + trees2 + reparaphrased_sentences_list + [three_D_plot]

# Logic for the new "Paraphrase and Discarded Sentence Generator" button
def generate_paraphrase_and_discarded_sentences(prompt):
    """Paraphrase ``prompt`` and return HTML for the selected and discarded paraphrases.

    Args:
        prompt: The user prompt taken from the input textbox.

    Returns:
        tuple: ``(paraphrased_html, discarded_html)`` where the first element
        is produced by ``highlight_common_words_dict`` with no common grams to
        highlight, and the second is the discarded sentences, each followed by
        its score formatted to two decimals, joined with ``<br>``.
    """
    candidates = generate_paraphrase(prompt)
    _, kept, rejected = analyze_entailment(prompt, candidates, 0.7)

    # One display line per discarded sentence, annotated with its score.
    rejected_lines = []
    for text, entailment in rejected.items():
        rejected_lines.append(f"{text} (Entailment Score: {entailment:.2f})")

    # An empty common-gram list means nothing is highlighted in the kept set.
    accepted_html = highlight_common_words_dict([], kept, "Paraphrased Sentences")
    rejected_html = "<br>".join(rejected_lines)

    return accepted_html, rejected_html

# Gradio UI: layout of the page and wiring of the buttons to their handlers.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# **AIISC Watermarking Model**")

    with gr.Row():
        user_input = gr.Textbox(label="User Prompt")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")
        generate_non_melting_point_button = gr.Button("Generate Non-Melting Point")  # New button
        paraphrase_discard_button = gr.Button("Paraphrase and Discarded Sentence Generator")

    with gr.Row():
        highlighted_user_prompt = gr.HTML()

    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Sentences"):
                highlighted_discarded_sentences = gr.HTML()

    # Adding labels before the tree plots
    with gr.Row():
        gr.Markdown("### Where to Watermark?")  # Label for masked sentences trees
    with gr.Row():
        with gr.Tabs():
            tree1_tabs = []
            for i in range(10):  # Adjust this range according to the number of trees
                with gr.TabItem(f"Sentence {i+1}"):
                    tree1_tabs.append(gr.Plot())

    with gr.Row():
        gr.Markdown("### How to Watermark?")  # Label for sampled sentences trees
    with gr.Row():
        with gr.Tabs():
            tree2_tabs = []
            for i in range(10):  # Adjust this range according to the number of trees
                with gr.TabItem(f"Sentence {i+1}"):
                    tree2_tabs.append(gr.Plot())

    # Adding the "Re-paraphrased Sentences" section
    with gr.Row():
        gr.Markdown("### Re-paraphrased Sentences")  # Label for re-paraphrased sentences

    # Adding tabs for the re-paraphrased sentences
    with gr.Row():
        with gr.Tabs():
            reparaphrased_sentences_tabs = []
            for i in range(120):  # 120 tabs for 120 batches of sentences
                with gr.TabItem(f"Sentence {i+1}"):
                    # Placeholder for each batch
                    reparaphrased_sentences_tabs.append(gr.HTML())

    with gr.Row():
        gr.Markdown("### 3D Plot for Sweet Spot")
    with gr.Row():
        three_D_plot = gr.Plot()

    # Every output component filled by the "Submit" handler, in the order the
    # handler returns its values; also the set that "Clear" resets.
    all_outputs = (
        [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences]
        + tree1_tabs
        + tree2_tabs
        + reparaphrased_sentences_tabs
        + [three_D_plot]
    )

    # Logic for the new button
    def generate_non_melting_points_only(prompt):
        """Return only the highlighted user prompt for ``prompt``."""
        paraphrased_sentences = generate_paraphrase(prompt)
        _, selected_sentences, _ = analyze_entailment(prompt, paraphrased_sentences, 0.7)
        common_grams = find_common_subsequences(prompt, selected_sentences)
        return highlight_common_words(common_grams, [prompt], "Non-melting Points in the User Prompt")

    def clear_outputs():
        """Return one reset value per output component.

        An event handler bound to several outputs must return one value for
        each of them; plots are reset with ``None`` and HTML panes with ``""``.
        """
        return [None if isinstance(component, gr.Plot) else "" for component in all_outputs]

    # Connect buttons to functions
    submit_button.click(
        model,
        inputs=user_input,
        outputs=all_outputs
    )
    generate_non_melting_point_button.click(
        generate_non_melting_points_only,
        inputs=user_input,
        outputs=highlighted_user_prompt
    )

    paraphrase_discard_button.click(
        generate_paraphrase_and_discarded_sentences,
        inputs=user_input,
        outputs=[highlighted_accepted_sentences, highlighted_discarded_sentences]
    )
    clear_button.click(lambda: "", inputs=None, outputs=user_input)
    clear_button.click(clear_outputs, inputs=None, outputs=all_outputs)

demo.launch(share=True)