import numpy as np
from sklearn.decomposition import PCA
import gensim.downloader as api
import gradio as gr
import plotly.graph_objects as go

# Load the pretrained "word2vec-google-news-300" embeddings through gensim's
# downloader API. This runs once, at import time; the resulting module-level
# `model` is what gradio_interface hands to the analogy and plotting helpers.
model = api.load("word2vec-google-news-300")


def gensim_analogy(model, word1, word2, word3):
    """Answer the analogy "word1 is to word2 as word3 is to ?".

    Args:
        model: Object exposing ``most_similar(positive=..., negative=..., topn=...)``
            that returns a sequence of ``(word, score)`` entries.
        word1: Word whose vector is subtracted (the ``negative`` term).
        word2: First word whose vector is added (a ``positive`` term).
        word3: Second word whose vector is added (a ``positive`` term).

    Returns:
        The best-matching word, or the text of the ``KeyError`` raised by the
        model lookup when that call fails with one.
    """
    try:
        matches = model.most_similar(
            positive=[word2, word3],
            negative=[word1],
            topn=1,
        )
    except KeyError as err:
        # Report the lookup failure as text instead of propagating it.
        return str(err)

    # Each entry is (word, score); only the word of the top hit is returned.
    top_match = matches[0]
    return top_match[0]


def plot_words_plotly(model, words):
    """Plot a 2-D PCA projection of the embedding vectors of ``words``.

    Args:
        model: Embedding lookup supporting ``model[word]`` and exposing a
            ``key_to_index`` mapping of the words it knows.
        words: Iterable of words to plot. Words that are not in
            ``model.key_to_index`` are skipped.

    Returns:
        A plotly ``Figure`` with one labelled marker per known word. When
        fewer than two of the words are known, the figure carries only the
        layout (title and axes) and no points, because a two-component PCA
        cannot be fitted on fewer than two samples.
    """
    # Filter the words once, up front, and use the same filtered list for both
    # the vectors and the labels; otherwise a skipped word would shift every
    # later label onto the wrong point.
    known_words = [word for word in words if word in model.key_to_index]

    # Create a scatter plot
    fig = go.Figure()

    if len(known_words) >= 2:
        vectors = np.array([model[word] for word in known_words])

        # Reduce dimensions to 2D for plotting
        vectors_2d = PCA(n_components=2).fit_transform(vectors)

        # One trace per word so each gets its own legend entry.
        for word, vec in zip(known_words, vectors_2d):
            fig.add_trace(go.Scatter(x=[vec[0]], y=[vec[1]],
                                     text=[word], mode='markers+text',
                                     textposition="bottom center",
                                     name=word))

    fig.update_layout(title="Visualization of Word Vectors",
                      xaxis_title="PCA 1",
                      yaxis_title="PCA 2",
                      showlegend=True,
                      width=600,  # Adjust width as needed
                      height=400)  # Adjust height as needed

    return fig


def _split_words(text):
    """Split comma-separated text into words with surrounding whitespace removed."""
    return [part.strip() for part in (text or "").split(",")]


def gradio_interface(choice, custom_input):
    """Solve the analogy for the chosen triplet and build the three UI outputs.

    Args:
        choice: Selected radio value: a preset "word1, word2, word3" string,
            the literal "Custom", or a falsy value when nothing is selected.
        custom_input: Text box contents; only read when ``choice`` is "Custom".
            Words may be separated by commas with or without spaces.

    Returns:
        A tuple ``(output, figure, vector_text)``:
        * ``output`` - the word returned by ``gensim_analogy`` (which may be
          its error text), or an "Invalid input" message.
        * ``figure`` - the plot of the known words, or ``None`` on invalid
          input or when fewer than two of the words are in the vocabulary.
        * ``vector_text`` - the rounded vector of the output word, a fallback
          message, or ``{"error": "Invalid input"}`` on invalid input.
    """
    if choice == "Custom":
        words = _split_words(custom_input)
        # Exactly three non-empty words are required; this also rejects
        # input such as "a, b, " whose last word is empty.
        if len(words) != 3 or not all(words):
            return "Invalid input. Please enter exactly three words, separated by commas.", None, {
                "error": "Invalid input"}
    else:
        if not choice:
            return "Invalid input. Please select or enter words.", None, {
                "error": "Invalid input"}
        words = _split_words(choice)

    word1, word2, word3 = words
    word4 = gensim_analogy(model, word1, word2, word3)

    # Pass only in-vocabulary words to the plot so every label has a vector
    # (word4 is error text when the analogy lookup failed). A two-component
    # PCA needs at least two vectors, so skip the plot below that.
    plottable = [word for word in (word1, word2, word3, word4)
                 if word in model.key_to_index]
    plot_fig = plot_words_plotly(model, plottable) if len(plottable) >= 2 else None

    if word4 in model.key_to_index:
        vector = model[word4]
        vector_display = f"{word4}: {np.round(vector, 2).tolist()}"
    else:
        vector_display = "Vector not available for the resulting word"

    return word4, plot_fig, vector_display


# Options shown in the radio group. Each preset is a comma-separated triplet
# that gradio_interface splits into (word1, word2, word3); the literal
# "Custom" tells gradio_interface to read the triplet from the text box instead.
choices = [
    "man, king, woman",
    "Paris, France, London",
    "strong, stronger, weak",
    "pork, pig, beef",
    "Custom"
]


def clear_inputs():
    """Return the reset values for the components wired to the Clear button.

    Returns:
        A 5-tuple in the order of the Clear button's ``outputs`` list:
        radio selection, custom words text, output word text, vectorization
        text, and plot.
    """
    # The radio is reset with None ("nothing selected") rather than "", since
    # an empty string is not one of the radio's choices. Text boxes are
    # emptied with "" and the plot is cleared with None.
    return None, "", "", "", None


# Define the layout using Rows and Columns.
# Components are placed in the order they are created inside the nested
# context managers, so the statement order below is the layout.
with gr.Blocks() as iface:
    with gr.Row():
        # The Column holds the title, inputs, buttons and the output word.
        with gr.Column():
            gr.Markdown("# Word Analogy and Vector Visualization")
            gr.Markdown(
                "Select a predefined triplet of words or choose 'Custom' and enter your own (comma-separated) to find a fourth word by analogy, and see their vectors plotted with Plotly.")

            radio = gr.Radio(choices=choices, label="Choose predefined words or enter custom words")

            custom_words = gr.Textbox(
                label="Custom words (comma-separated, required for custom choice; use only if 'Custom' is selected)",
                placeholder="Enter 3 words separated by commas")

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit")

            output_word = gr.Textbox(label="Output Word")

        # Created inside the outer Row but outside the Column, i.e. as a
        # sibling of the Column within the same Row.
        word_plot = gr.Plot(label="Word Vectors Visualization")

    # Second Row: the text box showing the output word's vector.
    with gr.Row():
        word_vectorization = gr.Textbox(label="Vectorization of the Output Word", lines=4, max_lines=4)

    # The order of `outputs` must match the order of the values returned by
    # clear_inputs: radio, custom words, output word, vectorization, plot.
    clear_btn.click(fn=clear_inputs, inputs=None,
                    outputs=[radio, custom_words, output_word, word_vectorization, word_plot])
    # gradio_interface receives (radio value, custom text) and returns
    # (output word, figure, vector text), matching the `outputs` order here.
    submit_btn.click(fn=gradio_interface, inputs=[radio, custom_words],
                     outputs=[output_word, word_plot, word_vectorization])

# NOTE: launch runs at import time (there is no `if __name__ == "__main__"`
# guard), and share=True asks Gradio for a publicly reachable share link.
iface.launch(share=True)