import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
import torch
import arxiv
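# Streamlit app: given an arXiv link/id (or a pasted abstract), fetch the
# paper's abstract and generate a candidate title with a seq2seq model from
# the Hugging Face Hub.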


def main():
    id_provided = True

    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Title Generator!",
        page_icon=None,
    )

    st.title("Title Generator: Generate a title from the abstract of a paper")
    st.text("")
    st.text("")

    example = st.text_area(
        "Provide the link/id for an arxiv paper",
        "https://arxiv.org/abs/2111.10339",
    )
    # st.selectbox("Provide the link/id for an arxiv paper", example_prompts)

    # Take the message which needs to be processed
    message = st.text_area("...or paste a paper's abstract to generate a title")
    if len(message) < 1:
        # No abstract pasted: treat the first field as an arXiv link/id,
        # then fetch the paper's abstract and title through the arxiv API.
        message = example
        id_provided = True
        paper_id = message.split('/')[-1]
        search = arxiv.Search(id_list=[paper_id])
        for result in search.results():
            message = result.summary
            title = result.title
    else:
        id_provided = False

    st.text("")
    models_to_choose = [
        "AryanLala/autonlp-Scientific_Title_Generator-34558227",
        "shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full"
    ]

    BASE_MODEL = st.selectbox("Choose a model to generate the title", models_to_choose)

    def preprocess(text):
        # Both supported models take a plain list of strings as tokenizer input.
        if BASE_MODEL in models_to_choose:
            return [text]
        else:
            st.error("Please select a model first")

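    # Load the tokenizer and model for the selected checkpoint; st.cache keeps
    # them in memory so they are not re-downloaded on every interaction.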
    @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
    def load_model():
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)
        return model, tokenizer

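    # Tokenize the (possibly truncated) abstract and generate a title with
    # beam search; the single highest-scoring sequence is returned.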
    def get_summary(text):
        with st.spinner(text="Processing your request"):
            model, tokenizer = load_model()
            preprocessed = preprocess(text)
            inputs = tokenizer(
                preprocessed, truncation=True, padding="longest", return_tensors="pt"
            )
            output = model.generate(
                **inputs,
                max_length=60,
                num_beams=10,
                num_return_sequences=1,
                temperature=1.5,  # only affects sampling; ignored by beam search since do_sample is not set
            )
            target_text = tokenizer.batch_decode(output, skip_special_tokens=True)
            return target_text[0]

    # Define function to run when submit is clicked
    def submit(message):
        if len(message) > 0:
            summary = get_summary(message)
            # Always show the generated title; when an arXiv id/link was
            # provided, also show the paper's original title for comparison.
            html_str = f"""
            <style>
            p.a {{
            font: 20px Courier;
            }}
            </style>
            <p class="a"><b>Title Generated:</b> {summary} </p>
            """
            if id_provided:
                html_str += f"""
                <p class="a"><b>Original Title:</b> {title} </p>
                """

            st.markdown(html_str, unsafe_allow_html=True)
            # st.markdown(emoji)
        else:
            st.error("The text can't be empty")

    # Run algo when submit button is clicked
    if st.button("Submit"):
        submit(message)

    with st.expander("Additional Information"):
        st.markdown("""
        The models used were fine-tuned on subset of data from the [Arxiv Dataset](https://huggingface.co/datasets/arxiv_dataset)
        The task of the models is to suggest an appropraite title from the abstract of a scientific paper.
        
        The model [AryanLala/autonlp-Scientific_Title_Generator-34558227]() was trained on data
        from the Cs.AI (Artificial Intelligence) category of papers.

        The model [shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full](https://huggingface.co/shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full) 
        was trained on the categories: cs.AI, cs.LG, cs.NI, cs.GR cs.CL, cs.CV (Artificial Intelligence, Machine Learning, Networking and Internet Architecture, Graphics, Computation and Language, Computer Vision and Pattern Recognition)

        Also, <b>Thank you to arXiv for use of its open access interoperability.</b> It allows us to pull the required abstracts from passed ids
        """,unsafe_allow_html=True,)

    st.text('\n')
    st.text('\n')
    st.markdown(
    '''<span style="color:blue; font-size:10px">App created by [@akshay7](https://huggingface.co/akshay7), [@AryanLala](https://huggingface.co/AryanLala) and [@shamikbose89](https://huggingface.co/shamikbose89)
      </span>''', 
     unsafe_allow_html=True,
    )

if __name__ == "__main__":
    main()