from langchain.chains.summarize import load_summarize_chain
from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.text_splitter import TokenTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain.docstore.document import Document
from langchain.tools import BaseTool, StructuredTool, Tool, tool
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.manager import BaseCallbackManager
from duckduckgo_search import DDGS
from itertools import islice

from typing import Any, Dict, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult

from pydantic import BaseModel, Field

import requests
from bs4 import BeautifulSoup
from threading import Thread, Condition
from collections import deque

from .base_model import BaseLLMModel, CallbackToIterator, ChuanhuCallbackHandler
from ..config import default_chuanhu_assistant_model
from ..presets import SUMMARIZE_PROMPT, i18n
from ..index_func import construct_index

from langchain.callbacks import get_openai_callback
import os
import gradio as gr
import logging

# Argument schema for the custom "Google Search JSON" tool registered in
# ChuanhuAgent_Client.__init__ (non-"Pro" models); that tool's function is
# google_search_simple, which receives the single value positionally.
class GoogleSearchInput(BaseModel):
    # Free-text search query chosen by the agent.
    keywords: str = Field(description="keywords to search")

# Single-argument schema for the "Summary Webpage" tool (summary_url).
# NOTE(review): it is also reused as the schema of the "Query Knowledge Base"
# tool, whose function (query_index) takes a query rather than a URL, so the
# field name/description shown to the agent is misleading there.
class WebBrowsingInput(BaseModel):
    # Address of the page to fetch.
    url: str = Field(description="URL of a webpage")

# Two-argument schema for the "Ask Webpage" StructuredTool (ask_url).
# Field names mirror ask_url(url, question); presumably they are passed to it as
# keyword arguments, so keep them in sync with that signature.
class WebAskingInput(BaseModel):
    # Address of the page to fetch.
    url: str = Field(description="URL of a webpage")
    # Question answered from the page's content.
    question: str = Field(description="Question that you want to know the answer to, based on the webpage's content.")


class ChuanhuAgent_Client(BaseLLMModel):
    """LangChain agent client that answers chat questions using tools.

    The agent is a structured-chat zero-shot ReAct agent. Its tools are stock
    LangChain tools plus custom ones for web search, webpage summarization and
    webpage Q&A. Once files have been uploaded, a knowledge-base query tool is
    added.
    """

    def __init__(self, model_name, openai_api_key, user_name="") -> None:
        """Create the LLMs, the summarize chain and the base tool list.

        Args:
            model_name: model name; if it contains "Pro", credential-gated tools
                (Google CSE, Wolfram Alpha, SerpAPI) are enabled instead of the
                DuckDuckGo-based search tools.
            openai_api_key: key used for the OpenAI calls made by this client.
            user_name: forwarded to BaseLLMModel as ``user``.
        """
        super().__init__(model_name=model_name, user=user_name)
        self.text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=30)
        self.api_key = openai_api_key
        # Optional alternative endpoint; None lets the library use its default.
        api_base = os.environ.get("OPENAI_API_BASE", None)
        self.llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0, model_name=default_chuanhu_assistant_model, openai_api_base=api_base)
        # Cheaper model for bulk work: webpage summaries and webpage Q&A.
        self.cheap_llm = ChatOpenAI(openai_api_key=openai_api_key, temperature=0, model_name="gpt-3.5-turbo", openai_api_base=api_base)
        PROMPT = PromptTemplate(template=SUMMARIZE_PROMPT, input_variables=["text"])
        self.summarize_chain = load_summarize_chain(self.cheap_llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
        # Populated by handle_file_upload().
        self.index_summary = None
        self.index = None
        if "Pro" in self.model_name:
            tools_to_enable = ["llm-math", "arxiv", "wikipedia"]
            # Optional tools are enabled only when their credentials are configured.
            if os.environ.get("GOOGLE_CSE_ID", None) is not None and os.environ.get("GOOGLE_API_KEY", None) is not None:
                tools_to_enable.append("google-search-results-json")
            else:
                logging.warning("GOOGLE_CSE_ID and/or GOOGLE_API_KEY not found, google-search-results-json is disabled.")
            if os.environ.get("WOLFRAM_ALPHA_APPID", None) is not None:
                tools_to_enable.append("wolfram-alpha")
            else:
                logging.warning("WOLFRAM_ALPHA_APPID not found, wolfram-alpha is disabled.")
            if os.environ.get("SERPAPI_API_KEY", None) is not None:
                tools_to_enable.append("serpapi")
            else:
                logging.warning("SERPAPI_API_KEY not found, serpapi is disabled.")
            self.tools = load_tools(tools_to_enable, llm=self.llm)
        else:
            self.tools = load_tools(["ddg-search", "llm-math", "arxiv", "wikipedia"], llm=self.llm)
            self.tools.append(
                Tool.from_function(
                    func=self.google_search_simple,
                    name="Google Search JSON",
                    description="useful when you need to search the web.",
                    args_schema=GoogleSearchInput
                )
            )

        self.tools.append(
            Tool.from_function(
                func=self.summary_url,
                name="Summary Webpage",
                description="useful when you need to know the overall content of a webpage.",
                args_schema=WebBrowsingInput
            )
        )

        self.tools.append(
            StructuredTool.from_function(
                func=self.ask_url,
                name="Ask Webpage",
                description="useful when you need to ask detailed questions about a webpage.",
                args_schema=WebAskingInput
            )
        )

    def google_search_simple(self, query):
        """Search the web and return up to 10 hits as ``str(list_of_dicts)``.

        Despite the name this queries DuckDuckGo (``lite`` backend); each hit
        is a dict with ``title``, ``link`` and ``snippet`` keys.
        """
        results = []
        with DDGS() as ddgs:
            ddgs_gen = ddgs.text(query, backend="lite")
            for r in islice(ddgs_gen, 10):
                results.append({
                    "title": r["title"],
                    "link": r["href"],
                    "snippet": r["body"]
                })
        return str(results)

    def handle_file_upload(self, files, chatbot, language):
        """Build a vector index from uploaded files and post a summary to the chat.

        Args:
            files: uploaded files; nothing is indexed when empty/None.
            chatbot: chat history list; an ``(upload message, summary)`` pair is
                appended to it.
            language: language name interpolated into the summary prompt.

        Returns:
            ``(gr.Files.update(), chatbot, status)`` for the Gradio outputs.
        """
        status = gr.Markdown.update()
        if files:
            index = construct_index(self.api_key, file_src=files)
            assert index is not None, "获取索引失败"
            self.index = index
            status = i18n("索引构建完成")
            # Summarize the uploaded documents.
            logging.info(i18n("生成内容总结中……"))
            with get_openai_callback() as cb:
                prompt_template = "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN " + language + ":"
                PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
                # Pass the key explicitly rather than writing it into os.environ:
                # the environment is process-wide, so that would leak this
                # user's key to every other session in the same process.
                llm = ChatOpenAI(openai_api_key=self.api_key, openai_api_base=os.environ.get("OPENAI_API_BASE", None))
                chain = load_summarize_chain(llm, chain_type="map_reduce", return_intermediate_steps=True, map_prompt=PROMPT, combine_prompt=PROMPT)
                # NOTE: reaches into the docstore's private `_dict` to get every stored Document.
                docs = list(index.docstore.__dict__["_dict"].values())
                summary = chain({"input_documents": docs}, return_only_outputs=True)["output_text"]
                logging.info(f"Summary: {summary}")
                self.index_summary = summary
                chatbot.append((f"Uploaded {len(files)} files", summary))
            logging.info(cb)
        return gr.Files.update(), chatbot, status

    def query_index(self, query):
        """Answer *query* from the uploaded-document index with RetrievalQA.

        Returns an error string when no index has been built yet.
        """
        if self.index is None:
            return "Error during query."
        retriever = self.index.as_retriever()
        qa = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff", retriever=retriever)
        return qa.run(query)

    def summary(self, text):
        """Summarize *text* with the map-reduce chain over ~500-token chunks."""
        texts = Document(page_content=text)
        texts = self.text_splitter.split_documents([texts])
        return self.summarize_chain({"input_documents": texts}, return_only_outputs=True)["output_text"]

    def fetch_url_content(self, url, timeout=30):
        """Download *url* and return the text of its ``<p>`` elements.

        Args:
            url: page to fetch.
            timeout: seconds to wait on the server; without a timeout a stalled
                server would block the agent indefinitely.

        Returns:
            Paragraph texts joined by newlines, or ``""`` when the request fails
            (callers report ``""`` as "URL unavailable.").
        """
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException as e:
            # A failed fetch is a tool result, not a reason to abort the agent run.
            logging.warning(f"Failed to fetch {url}: {e}")
            return ""
        soup = BeautifulSoup(response.text, 'html.parser')

        # Extract all paragraph text; newline-separate paragraphs so words at
        # paragraph boundaries are not glued together.
        text = '\n'.join(p.getText() for p in soup.find_all('p'))
        logging.info(f"Extracted text from {url}")
        return text

    def summary_url(self, url):
        """Return a summary of the webpage at *url*, or "URL unavailable."."""
        text = self.fetch_url_content(url)
        if text == "":
            return "URL unavailable."
        text_summary = self.summary(text)
        url_content = "webpage content summary:\n" + text_summary

        return url_content

    def ask_url(self, url, question):
        """Answer *question* about the page at *url* via a throwaway FAISS index.

        Returns "URL unavailable." when no text could be extracted.
        """
        text = self.fetch_url_content(url)
        if text == "":
            return "URL unavailable."
        texts = Document(page_content=text)
        texts = self.text_splitter.split_documents([texts])
        # Embed the chunks so only the relevant ones are stuffed into the prompt.
        embeddings = OpenAIEmbeddings(openai_api_key=self.api_key, openai_api_base=os.environ.get("OPENAI_API_BASE", None))

        db = FAISS.from_documents(texts, embeddings)
        retriever = db.as_retriever()
        qa = RetrievalQA.from_chain_type(llm=self.cheap_llm, chain_type="stuff", retriever=retriever)
        return qa.run(f"{question} Reply in 中文")

    def _build_tools(self):
        """Return a fresh tool list for one agent run.

        Copies ``self.tools`` (never mutates it, so repeated calls do not pile up
        duplicate tools) and, when a document index exists, appends a
        knowledge-base tool described by the current index summary.
        """
        tools = list(self.tools)
        if self.index is not None:
            tools.append(
                Tool.from_function(
                    func=self.query_index,
                    name="Query Knowledge Base",
                    description=f"useful when you need to know about: {self.index_summary}",
                    args_schema=WebBrowsingInput
                )
            )
        return tools

    def get_answer_at_once(self):
        """Run the agent on the latest user message; return ``(reply, -1)``.

        The second element is always -1 (presumably "token count unknown" for
        BaseLLMModel — confirm against the base class).
        """
        question = self.history[-1]["content"]
        agent = initialize_agent(self._build_tools(), self.llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
        reply = agent.run(input=f"{question} Reply in 简体中文")
        return reply, -1

    def get_answer_stream_iter(self):
        """Run the agent in a background thread, yielding the accumulated text.

        A ChuanhuCallbackHandler is given ``it.callback`` (presumably to forward
        intermediate agent output); the final reply is pushed the same way.
        Each yielded value is everything received so far.
        """
        question = self.history[-1]["content"]
        it = CallbackToIterator()
        manager = BaseCallbackManager(handlers=[ChuanhuCallbackHandler(it.callback)])

        def thread_func():
            try:
                # Agent construction is inside the try too: if it raised outside,
                # the iterator would never be finished and the consumer would hang.
                agent = initialize_agent(self._build_tools(), self.llm, agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)
                reply = agent.run(input=f"{question} Reply in 简体中文")
            except Exception as e:
                # Top-level boundary of the worker thread: report the error as the reply.
                logging.exception("Agent run failed")
                reply = str(e)
            try:
                it.callback(reply)
            finally:
                # Always terminate the iterator so the loop below can exit.
                it.finish()

        t = Thread(target=thread_func)
        t.start()
        partial_text = ""
        for value in it:
            partial_text += value
            yield partial_text