import json
import logging
import os
import sys
from threading import Lock

import gradio as gr
import s3fs
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import (ServiceContext, StorageContext,
                         load_index_from_storage, set_global_service_context)
from llama_index.agent import ContextRetrieverOpenAIAgent, OpenAIAgent
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.llms import ChatMessage, MessageRole, OpenAI
from llama_index.prompts import ChatPromptTemplate, PromptTemplate
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.retrievers import RecursiveRetriever
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.vector_stores import PGVectorStore
from sqlalchemy import make_url


def get_embed_model():
    """Build the HuggingFace sentence-embedding model used by the indices.

    The device defaults to CPU. It switches to CUDA when available, then to
    Apple MPS when available. MPS is checked last, so it wins if both report
    available.

    Returns:
        A ``HuggingFaceEmbeddings`` wrapping ``thenlper/gte-small`` with
        ``normalize_embeddings=True``.

    Exits the process with status 1 if the model cannot be constructed.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    if torch.backends.mps.is_available():
        device = "mps"
    model_kwargs = {"device": device}

    # Normalised vectors make the dot product equal to cosine similarity.
    encode_kwargs = {"normalize_embeddings": True}
    print("Loading model...")
    try:
        model_norm = HuggingFaceEmbeddings(
            model_name="thenlper/gte-small",
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    except Exception as exception:
        # Nothing downstream works without embeddings, so fail loudly.
        # The previous bare ``exit()`` reported success (status 0) and its
        # message wrongly claimed a fallback model was being loaded.
        print(f"Failed to load embedding model: {exception}", file=sys.stderr)
        sys.exit(1)
    print("Model loaded.")
    return model_norm

# Build the embedding model and the GPT-4 LLM, then install them as the global
# llama_index defaults. Indices and query engines created without an explicit
# service context will then use them.
embed_model = get_embed_model()
llm = OpenAI("gpt-4")
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
)
set_global_service_context(service_context)

# Credentials are read from the environment. A missing variable raises
# KeyError at startup.
s3 = s3fs.S3FileSystem(
    key=os.environ["AWS_CANONICAL_KEY"],
    secret=os.environ["AWS_CANONICAL_SECRET"],
)

# Keep only the last path component of each entry under recursive-agent/.
# These names are presumably per-section folders plus the top-level
# "vector_index".
titles = [entry.split("/")[-1] for entry in s3.ls("f150-user-manual/recursive-agent/")]

# Map each section title to a query engine over that section's persisted
# vector index. "vector_index" is the top-level index directory, not a
# section, so it is skipped here.
agents = {}
for title in titles:
    if title != "vector_index":
        print(title)
        section_storage = StorageContext.from_defaults(
            persist_dir=f"f150-user-manual/recursive-agent/{title}/vector_index",
            fs=s3,
        )
        agents[title] = load_index_from_storage(section_storage).as_query_engine(
            similarity_top_k=2,
            verbose=True,
        )
print(f"Agents: {len(agents)}")
# Load the top-level index and retrieve only its single best match.
# RecursiveRetriever routes a match into the query engine in ``agents`` whose
# key it references. Presumably the top-level nodes are IndexNodes whose
# index_id is a section title — TODO confirm against the indexing script.
storage_context = StorageContext.from_defaults(
    persist_dir="f150-user-manual/recursive-agent/vector_index",
    fs=s3,
)
top_level_vector_index = load_index_from_storage(storage_context)
vector_retriever = top_level_vector_index.as_retriever(similarity_top_k=1)
recursive_retriever = RecursiveRetriever(
    "vector",
    retriever_dict={"vector": vector_retriever},
    query_engine_dict=agents,
    # Emit only the sub-engine's response text, without the query prefix.
    query_response_tmpl="{response}",
    verbose=True,
)

# Both Gradio handlers hold this lock while calling recursive_retriever, so
# retrievals run one at a time.
lock = Lock()

def predict(message):
    """Return the text of the top node retrieved for *message*.

    Handler for the hidden "Predict" button. The shared
    ``recursive_retriever`` is only called while holding ``lock``.

    Args:
        message: The user's query string.

    Returns:
        ``get_text()`` of the first node returned by the retriever.

    Raises:
        IndexError: If the retriever returns an empty result.
        Exception: Any retriever error is printed and then re-raised.
    """
    print(message)
    with lock:
        try:
            node = recursive_retriever.retrieve(message)[0]
            output = node.get_text()
        except Exception as e:
            print(e)
            # A bare ``raise`` re-raises with the original traceback intact.
            raise
    return output

def getanswer(question, history):
    """Answer *question* and append the exchange to the chat history.

    Wired to both the "Send" button click and the textbox submit event.

    Args:
        question: The user's query. If it has a ``.value`` attribute, that
            value is used instead. Presumably this tolerates being handed a
            component rather than its value.
        history: The list of ``(question, answer)`` tuples kept in
            ``gr.State``. A falsy value starts a new list. It is unwrapped
            from ``.value`` the same way as *question*.

    Returns:
        ``(history, history, gr.update(value=""))``: the updated history for
        the Chatbot and for the State, plus an update that clears the textbox.

    Raises:
        Any retriever exception propagates unchanged, including
        ``IndexError`` when no node is returned. Nothing is appended in
        that case.
    """
    print("getting answer")
    if hasattr(history, "value"):
        history = history.value
    if hasattr(question, "value"):
        question = question.value

    history = history or []
    with lock:
        # The shared retriever is only ever called under ``lock``.
        answer = recursive_retriever.retrieve(question)[0].get_text()
        history.append((question, answer))
    return history, history, gr.update(value="")

# Chat UI: a heading, the conversation pane, a single-line input box and a
# Send button. Clicking Send or submitting the textbox both call getanswer.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=0.75):
            with gr.Row():
                gr.Markdown("<h1>F150 User Manual</h1>")
            chatbot = gr.Chatbot(elem_id="chatbot").style(height=600)

            with gr.Row():
                message = gr.Textbox(label="", placeholder="F150 User Manual", lines=1)
            with gr.Row():
                submit = gr.Button(value="Send", variant="primary", scale=1)

            state = gr.State()
            # Identical handler, inputs and outputs for both triggers.
            # The click listener is registered first, then submit.
            for register_listener in (submit.click, message.submit):
                register_listener(
                    getanswer,
                    inputs=[message, state],
                    outputs=[chatbot, state, message],
                )

            # Invisible button wired to predict. Presumably this exposes predict
            # as an API endpoint — TODO confirm intent.
            predictBtn = gr.Button(value="Predict", visible=False)
            predictBtn.click(predict, inputs=[message], outputs=[message])

demo.launch(debug=True)