import time
from about import show_about_ask2democracy
import streamlit as st
from pinecone_quieries import PineconeProposalQueries
from config import Config
from samples import *

# Query client shared by every search helper in this module; it is built once
# each time this module is executed, at import time.
# NOTE(review): Streamlit re-executes the entry script on every interaction, so
# if this file is the script passed to `streamlit run` the client (and any model
# it loads) is rebuilt on each rerun -- consider caching it (e.g. with
# `st.cache_resource`, if the installed Streamlit version provides it).
# NOTE(review): the Pinecone API key is read from `Config.es_password`; confirm
# that attribute really holds the Pinecone key.
# No OpenAI key is given here: one is passed per request when generating answers.
queries = PineconeProposalQueries(
    index_name=Config.index_name,
    api_key=Config.es_password,
    environment=Config.pinecone_environment,
    embedding_dim=Config.embedding_dim,
    reader_name_or_path=Config.reader_model_name_or_path,
    use_gpu=Config.use_gpu,
    OPENAI_key=None,
)

def search(question, retriever_top_k, reader_top_k, selected_index=None):
    """Run an extractive (retriever + reader) search and flatten the answers.

    Args:
        question: Free-text question typed by the user.
        retriever_top_k: Forwarded to ``queries.search_by_query`` as
            ``retriever_top_k``.
        reader_top_k: Forwarded to ``queries.search_by_query`` as
            ``reader_top_k``.
        selected_index: Value used for the ``source_title`` metadata filter.
            It is passed through as-is, so ``None`` produces the filter
            ``{"source_title": None}``.

    Returns:
        One row per answer, in the order returned by the query client::

            [[rank], answer, context, title, source_title, page, source_url]

        ``rank`` is 1-based, ``answer`` has its newlines removed, ``context``
        is cut to its first 250 characters and ``page`` is converted to int.
    """
    filters = {"source_title": selected_index}
    query_result = queries.search_by_query(query=question,
                                           retriever_top_k=retriever_top_k,
                                           reader_top_k=reader_top_k,
                                           filters=filters)
    return [
        [[rank], item.answer.replace("\n", ""), item.context[:250],
         item.meta['title'], item.meta['source_title'],
         int(item.meta['page']), item.meta['source_url']]
        for rank, item in enumerate(query_result, start=1)
    ]
    
def search_and_show_results(query: str, retriever_top_k=5, reader_top_k=3, selected_index=None):
    """Run an extractive search and render each answer in the Streamlit page.

    Args:
        query: Question typed by the user; echoed in the results header.
        retriever_top_k: Forwarded to ``search``.
        reader_top_k: Forwarded to ``search``.
        selected_index: Forwarded to ``search`` as the ``source_title`` filter.

    For every answer this writes a subheader with the answer text, a context
    snippet followed by a "read more" link to the source URL, and a caption
    with the source, article and page.
    """
    # perf_counter is monotonic, so the elapsed time cannot go negative or jump
    # if the system clock is adjusted during the search.
    started = time.perf_counter()
    results = search(query, retriever_top_k=retriever_top_k,
                     reader_top_k=reader_top_k, selected_index=selected_index)
    elapsed_time = round(time.perf_counter() - started, 2)

    st.write(f"**Resultados encontrados para la pregunta** \"{query}\" ({elapsed_time} sec.):")
    for _rank, answer, context, title, source_title, page, source_url in results:
        st.subheader(f"{answer}")
        snippet = context[:250] + "..."
        # The snippet is retrieved document text. It is rendered as plain
        # markdown (no raw HTML) because a markdown link needs no HTML support
        # and any markup in the documents must not be injected into the page.
        st.markdown(f"{snippet}[Lee más aquí]({source_url})")
        st.caption(f"Fuente: {source_title} - Artículo: {title} - Página: {page}")
                
def search_and_generate_answer(question, retriever_top_k, generator_top_k,
                               openai_api_key, openai_model_name="text-davinci-003",
                               temperature=.5, max_tokens=30, selected_index=None):
    """Retrieve documents and generate answers with OpenAI, then flatten them.

    Args:
        question: Free-text question typed by the user.
        retriever_top_k: Forwarded to ``queries.genenerate_answer_OpenAI``.
        generator_top_k: Forwarded to ``queries.genenerate_answer_OpenAI``.
        openai_api_key: OpenAI key supplied per request (``OPENAI_key``).
        openai_model_name: Model name forwarded to the generation call.
        temperature: Sampling temperature forwarded to the generation call.
        max_tokens: Token limit forwarded to the generation call.
        selected_index: Value used for the ``source_title`` metadata filter,
            passed through as-is (``None`` included).

    Returns:
        One row per generated answer::

            [[rank], answer, source_title, source_url, chapter_titles]

        ``rank`` is 1-based and ``answer`` has its newlines removed.
        ``source_title`` and ``source_url`` come from the FIRST entry of the
        answer's ``meta['doc_metas']``. ``chapter_titles`` is ``str()`` of the
        list of ``title`` values of all those entries.

    Raises:
        IndexError: if an answer's ``meta['doc_metas']`` is empty.
    """
    filters = {"source_title": selected_index}

    query_result = queries.genenerate_answer_OpenAI(query=question,
                                                    retriever_top_k=retriever_top_k,
                                                    generator_top_k=generator_top_k,
                                                    filters=filters, OPENAI_key=openai_api_key,
                                                    openai_model_name=openai_model_name,
                                                    temperature=temperature, max_tokens=max_tokens)
    result = []
    for rank, item in enumerate(query_result, start=1):
        doc_metas = item.meta['doc_metas']
        # Only the first supporting document provides the source title/link.
        first_doc = doc_metas[0]
        chapter_titles = [doc['title'] for doc in doc_metas]
        result.append([[rank], item.answer.replace("\n", ""),
                       first_doc['source_title'], first_doc['source_url'],
                       str(chapter_titles)])
    return result

def search_and_show_generative_results(query: str, retriever_top_k=5, generator_top_k=1,
                                       openai_api_key=None, openai_model_name="text-davinci-003",
                                       temperature=.5, max_tokens=30, selected_index=None):
    """Generate answers with OpenAI and render them in the Streamlit page.

    All keyword arguments are forwarded unchanged to
    ``search_and_generate_answer``. For every generated answer this writes a
    subheader with the answer text, a caption with the source title and the
    supporting chapter titles, and a "read more" link to the source URL.
    """
    # perf_counter is monotonic and therefore suited to measuring durations.
    started = time.perf_counter()
    results = search_and_generate_answer(query, retriever_top_k=retriever_top_k,
                                         generator_top_k=generator_top_k,
                                         openai_api_key=openai_api_key,
                                         openai_model_name=openai_model_name,
                                         temperature=temperature, max_tokens=max_tokens,
                                         selected_index=selected_index)
    elapsed_time = round(time.perf_counter() - started, 2)
    st.write(f"**Respuesta generada para la pregunta**  \"{query}\" ({elapsed_time} sec.):")
    if results is not None:
        for _rank, answer, source_title, source_url, chapter_titles in results:
            st.subheader(f"{answer}")
            st.caption(f"Fuentes: {source_title} - {chapter_titles}")
            st.markdown(f"[Lee más aquí]({source_url})")
        

# Documents the user can explore. In each entry:
#   "title"   - label offered in the document selectbox in `main`;
#   "name"    - value used as `selected_index`, i.e. the `source_title`
#               search filter;
#   "samples" - text of sample questions, one per line (split with
#               `.splitlines()` in `main`).
indexes = [
    {
        "title": "Propuesta reforma a la salud 13 de febrero de 2023",
        "name": "Reforma de la salud 13 Febrero 2023",
        "samples": samples_reforma_salud,
    },
    {
        "title": "Propuesta reforma pensional marzo 22 de 2023",
        "name": "Reforma pensional Marzo 2023",
        "samples": samples_reforma_pensional,
    },
    {
        "title": "Hallazgos de la comisión de la verdad",
        "name": "Hallazgos y recomendaciones - 28 de Junio 2022",
        "samples": samples_hallazgos_paz,
    },
]

# Selectbox labels, in the same order as `indexes`.
index_titles = [entry["title"] for entry in indexes]

def get_selected_index_by_title(title):
    """Return the ``name`` of the first ``indexes`` entry titled *title*.

    The name is the value used as the ``source_title`` search filter.
    Returns ``None`` when no entry has that title.
    """
    return next((entry["name"] for entry in indexes if entry["title"] == title), None)

def get_samples_for_index(title):
    """Return the sample-questions text of the first ``indexes`` entry titled *title*.

    Returns ``None`` when no entry has that title.
    """
    for entry in indexes:
        if entry["title"] != title:
            continue
        return entry["samples"]
    return None

def main():
    """Render the Ask2Democracy Streamlit page.

    Builds the sidebar (retrieval parameters and OpenAI settings) and the
    question form. Before any submission it shows the "about" section. On a
    submission with a non-blank question it shows, in order:

    - a generated answer, but only when an OpenAI API key was entered;
    - the extractive search results.

    A blank question only produces a warning.
    """
    st.title("Ask2Democracy 🇨🇴")
    st.markdown("""
    <div align="right">
    Creado por Jorge Henao 🇨🇴 <a href="https://twitter.com/jhenaotw" target='_blank'>Twitter</a> <a href="https://www.linkedin.com/in/henaojorge" target='_blank'>LinkedIn</a> <a href="https://linktr.ee/jorgehenao" target='_blank'>Linktree</a>
    </div>""", unsafe_allow_html=True)

    with st.form("my_form"):
        st.sidebar.title("Configuración de búsqueda")
        with st.sidebar.expander("Parámetros de recuperación", expanded= True):
            index = st.selectbox("Selecciona el documento que deseas explorar", index_titles)
            top_k_retriever = st.slider("Retriever Top K", 1, 10, 5)
            top_k_reader = st.slider("Reader Top K", 1, 10, 3)

        with st.sidebar.expander("Configuración OpenAI"):
            openai_api_key = st.text_input("API Key", type="password", placeholder="Copia aquí tu OpenAI API key (no será guardada)",
                            help="puedes obtener tu api key de OpenAI en https://platform.openai.com/account/api-keys.")
            openai_api_model = st.text_input("Modelo", value= "text-davinci-003")
            openai_api_temp = st.slider("Temperatura", 0.1, 1.0, 0.5, step=0.1)
            openai_api_max_tokens = st.slider("Max tokens", 10, 100, 60, step=10)

        # `index` always comes from `index_titles`, so a samples entry exists.
        sample_questions = get_samples_for_index(index).splitlines()
        # NOTE(review): newer Streamlit versions warn about empty widget labels;
        # consider a real label with label_visibility="collapsed" if supported.
        query = st.text_area("",placeholder="Escribe aquí tu pregunta, cuanto más contexto le des, mejor serán las respuestas")
        with st.expander("Algunas preguntas de ejemplo", expanded= False):
            for sample in sample_questions:
                st.markdown(f"- {sample}")

        submitted = st.form_submit_button("Buscar")

    if not submitted:
        show_about_ask2democracy()
        return

    # Do not run the retriever -- or spend OpenAI tokens -- on a blank question.
    if not query.strip():
        st.warning("Escribe una pregunta antes de buscar.")
        return

    selected_index = get_selected_index_by_title(index)
    if openai_api_key:
        with st.expander("", expanded= True):
            search_and_show_generative_results(query = query,retriever_top_k= top_k_retriever,
                                                generator_top_k= 1, openai_api_key = openai_api_key,
                                                openai_model_name = openai_api_model,
                                                temperature= openai_api_temp,
                                                max_tokens= openai_api_max_tokens,
                                                selected_index = selected_index)
    with st.expander("", expanded= True):
        search_and_show_results(query, retriever_top_k=top_k_retriever,
                                reader_top_k=top_k_reader,
                                selected_index=selected_index)

# Script entry point: render the app when this file is run as the main module.
if __name__ == "__main__":
    main()