import json
import os
import tempfile

import requests
import streamlit as st
from lagent.schema import AgentStatusCode
from pyvis.network import Network


# Function to create the network graph
def create_network_graph(nodes, adjacency_list):
    """Build a pyvis ``Network`` visualising the search graph.

    Args:
        nodes: mapping of node id -> node data. For the ``'root'`` and
            ``'response'`` nodes the hover title is their ``'content'``
            (falling back to the id); every other node uses
            ``node_data['detail']['content']``.
        adjacency_list: mapping of node id -> iterable of neighbour dicts,
            each carrying a ``'name'`` key. Edges whose target is not in
            *nodes* are skipped.

    Returns:
        The populated ``Network``, with pyvis's physics option buttons shown.
    """
    special_ids = ('root', 'response')
    graph = Network(height='500px',
                    width='60%',
                    bgcolor='white',
                    font_color='black')
    for node_id, node_data in nodes.items():
        if node_id not in special_ids:
            hover = node_data['detail']['content']
        else:
            hover = node_data.get('content', node_id)
        graph.add_node(node_id,
                       label=node_id,
                       title=hover,
                       color='#FF5733',
                       size=25)
    # Lazily walk every (source, target) pair in adjacency order.
    pairs = ((source, neighbour['name'])
             for source, neighbours in adjacency_list.items()
             for neighbour in neighbours)
    for source, target in pairs:
        if target in nodes:
            graph.add_edge(source, target)
    graph.show_buttons(filter_=['physics'])
    return graph


# Function to draw the graph and return the HTML file path
def draw_graph(net):
    """Save *net* as HTML in a new temporary file and return the file's path.

    The file is created atomically by ``tempfile`` (unlike the deprecated,
    race-prone ``tempfile.mktemp``). It is not deleted here; the caller owns
    it and should remove it when done.

    Args:
        net: object exposing ``save_graph(path)``, e.g. a pyvis ``Network``.

    Returns:
        str: path of the written ``.html`` file.
    """
    # delete=False + closing the handle immediately: save_graph reopens the
    # path by name, which an open NamedTemporaryFile does not allow on
    # Windows. The (empty) file keeps the name reserved until it is written.
    with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as tmp:
        path = tmp.name
    net.save_graph(path)
    return path


def streaming(raw_response):
    """Yield ``(response, current_node)`` pairs from a server-sent-event stream.

    Each line is decoded as UTF-8 and stripped. Blank lines and SSE comment
    lines (starting with ``:``, e.g. ``: ping - ...`` keep-alives) are
    skipped, as are the ``event:``, ``id:`` and ``retry:`` fields. A
    ``data:`` prefix is removed before the payload is parsed as JSON; lines
    without a field prefix are parsed as raw JSON.

    Args:
        raw_response: a ``requests.Response`` opened with ``stream=True``.

    Yields:
        tuple: ``(payload['response'], payload['current_node'])`` for each
        JSON payload.

    Raises:
        json.JSONDecodeError: if a payload line is not valid JSON.
        KeyError: if a payload lacks ``'response'`` or ``'current_node'``.
    """
    for chunk in raw_response.iter_lines(chunk_size=8192,
                                         decode_unicode=False,
                                         delimiter=b'\n'):
        # Splitting on b'\n' leaves the '\r' of CRLF line endings attached;
        # strip() removes it together with any other surrounding whitespace.
        line = chunk.decode('utf-8').strip()
        if not line or line.startswith(':'):
            continue
        if line.startswith('data:'):
            line = line[len('data:'):].strip()
            if not line:
                continue
        elif line.startswith(('event:', 'id:', 'retry:')):
            # Non-data SSE fields carry no payload for this client.
            continue
        payload = json.loads(line)
        yield payload['response'], payload['current_node']


# Seed the per-session state on the first script run of a browser session;
# the 'queries' key doubles as the "already initialized" marker.
if 'queries' not in st.session_state:
    for _state_key in ('queries', 'responses', 'graphs_html', 'nodes_list',
                       'adjacency_list_list', 'history',
                       'already_used_keys'):
        # A fresh list per key, so no two keys share one object.
        st.session_state[_state_key] = []
    del _state_key

# Page layout: wide mode plus the app title.
st.set_page_config(layout='wide')
st.title('MindSearch-思索')


# Function to update chat
def update_chat(query):
    """Send *query* to the MindSearch backend and stream the answer into the UI.

    Echoes the query as a user chat message. If it was not already asked in
    this session, it is POSTed to ``http://localhost:8002/solve`` and every
    streamed event is rendered (search graph, planner text, selected
    searcher node). When the stream ends, the last event's graph HTML,
    nodes, adjacency list and inner steps are recorded in
    ``st.session_state``.

    Args:
        query: the user's question.

    Raises:
        requests.HTTPError: if the backend answers with an error status.
    """
    with st.chat_message('user'):
        st.write(query)
    if query in st.session_state['queries']:
        return
    st.session_state['queries'].append(query)
    st.session_state['responses'].append([])
    # Multi-turn conversations are not supported yet: only this query is sent.
    message = [dict(role='user', content=query)]

    url = 'http://localhost:8002/solve'
    headers = {'Content-Type': 'application/json'}
    data = {'inputs': message}

    # Defaults keep the bookkeeping below well-defined even when the stream
    # yields no renderable event.
    graph_html = None
    nodes = {}
    adjacency_list = {}
    history = None
    # The with-block releases the streamed connection, also on errors.
    # Note: requests' timeout bounds the connect and each read, not the
    # total streaming time.
    with requests.post(url,
                       headers=headers,
                       data=json.dumps(data),
                       timeout=20,
                       stream=True) as raw_response:
        raw_response.raise_for_status()
        for agent_return, node_name in streaming(raw_response):
            # Events of the 'root' and 'response' nodes are not rendered.
            if node_name in ('root', 'response'):
                continue
            nodes = agent_return['nodes']
            adjacency_list = agent_return['adj']
            history = agent_return['inner_steps']
            graph_html = _render_stream_event(agent_return, node_name, nodes,
                                              adjacency_list)

    # Commit planner text that was still buffered when the stream ended.
    if st.session_state.get('session_info_temp'):
        st.session_state['responses'][-1].append(
            st.session_state['session_info_temp'])
        st.session_state['session_info_temp'] = ''
    st.session_state['graphs_html'].append(graph_html)
    st.session_state['nodes_list'].append(nodes)
    st.session_state['adjacency_list_list'].append(adjacency_list)
    st.session_state['history'] = history


def _graph_html(nodes, adjacency_list):
    """Render the search graph to an HTML string, removing the temp file."""
    path = draw_graph(create_network_graph(nodes, adjacency_list))
    try:
        with open(path, encoding='utf-8') as f:
            return f.read()
    finally:
        os.remove(path)


def _ensure_placeholder(key):
    """Store a new ``st.empty()`` under *key* in session state if missing."""
    if key not in st.session_state:
        st.session_state[key] = st.empty()


def _render_stream_event(agent_return, node_name, nodes, adjacency_list):
    """Redraw the graph and both result columns for one streamed event.

    Returns:
        The rendered graph HTML, or ``None`` when *nodes* is empty.
    """
    graph_html = _graph_html(nodes, adjacency_list) if nodes else None
    # Placeholders are created on the first event (in this order, which
    # fixes their position on the page) and reused afterwards, so every
    # event overwrites the previous render.
    _ensure_placeholder('graph_placeholder')
    _ensure_placeholder('expander_placeholder')
    if graph_html:
        with st.session_state['expander_placeholder'].expander(
                'Show Graph', expanded=False):
            st.session_state['graph_placeholder']._html(graph_html,
                                                        height=500)
    _ensure_placeholder('container_placeholder')
    with st.session_state['container_placeholder'].container():
        _ensure_placeholder('columns_placeholder')
        col1, col2 = st.session_state['columns_placeholder'].columns([2, 1])
        with col1:
            _render_planner(agent_return, node_name)
        with col2:
            _render_searcher(node_name, nodes)
    return graph_html


def _render_planner(agent_return, node_name):
    """Show the planner's streamed text in the left column.

    While a searcher node is streaming (*node_name* set), the planner text
    is only re-displayed. Otherwise the event updates the buffer
    ``session_info_temp``, which is committed to the current query's
    responses on ``PLUGIN_RETURN``.
    """
    ss = st.session_state
    _ensure_placeholder('planner_placeholder')
    if 'session_info_temp' not in ss:
        ss['session_info_temp'] = ''
    if node_name:
        committed = ss['responses'][-1]
        # Fall back to '' when nothing has been committed for this query yet.
        ss['planner_placeholder'].markdown(
            ss['session_info_temp'] or (committed[-1] if committed else ''))
        return
    state = agent_return['state']
    response = agent_return['response']
    if state in (AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING):
        ss['session_info_temp'] = response
    elif state == AgentStatusCode.PLUGIN_START:
        # Keep the text before the first code fence, then append the
        # fenced response (presumably the plugin call being issued).
        thought = ss['session_info_temp'].split('```')[0]
        if response.startswith('```'):
            ss['session_info_temp'] = thought + '\n' + response
    elif state == AgentStatusCode.PLUGIN_RETURN:
        assert agent_return['inner_steps'][-1]['role'] == 'environment'
        ss['session_info_temp'] += (
            '\n' + agent_return['inner_steps'][-1]['content'])
    ss['planner_placeholder'].markdown(ss['session_info_temp'])
    if state == AgentStatusCode.PLUGIN_RETURN:
        ss['responses'][-1].append(ss['session_info_temp'])
        ss['session_info_temp'] = ''


def _render_searcher(node_name, nodes):
    """Show the streaming searcher node's progress in the right column.

    The first event of each node (per query) renders a node selectbox. The
    node's text is buffered in ``node_info_temp`` and its transcript is
    appended to ``st.session_state['<node>_info']`` as
    ``('thought' | 'observation' | 'answer', text)`` tuples for replay.
    """
    ss = st.session_state
    _ensure_placeholder('selectbox_placeholder')
    _ensure_placeholder('searcher_placeholder')
    if not node_name:
        return
    selected_node_key = f"selected_node_{len(ss['queries'])}_{node_name}"
    if selected_node_key not in ss:
        ss[selected_node_key] = node_name
    if selected_node_key not in ss['already_used_keys']:
        node_names = list(nodes.keys())
        selected_node = ss['selectbox_placeholder'].selectbox(
            'Select a node:',
            node_names,
            key=f'key_{selected_node_key}',
            index=node_names.index(node_name))
        ss['already_used_keys'].append(selected_node_key)
    else:
        selected_node = node_name
    ss[selected_node_key] = selected_node
    if selected_node not in nodes:
        return
    detail = nodes[selected_node]['detail']
    node_info_key = f'{selected_node}_info'
    if 'node_info_temp' not in ss:
        ss['node_info_temp'] = f'### {detail["content"]}'
    if node_info_key not in ss:
        ss[node_info_key] = []
    state = detail['state']
    if state in (AgentStatusCode.STREAM_ING, AgentStatusCode.ANSWER_ING):
        ss['node_info_temp'] = detail['response']
    elif state == AgentStatusCode.PLUGIN_START:
        thought = ss['node_info_temp'].split('```')[0]
        if detail['response'].startswith('```'):
            ss['node_info_temp'] = thought + '\n' + detail['response']
    elif state == AgentStatusCode.PLUGIN_END:
        thought = ss['node_info_temp'].split('```')[0]
        if isinstance(detail['response'], dict):
            pretty = json.dumps(detail['response'],
                                ensure_ascii=False,
                                indent=4)
            ss['node_info_temp'] = thought + '\n' + f'```json\n{pretty}\n```'
    elif state == AgentStatusCode.PLUGIN_RETURN:
        assert detail['inner_steps'][-1]['role'] == 'environment'
        ss[node_info_key].append(('thought', ss['node_info_temp']))
        ss[node_info_key].append(
            ('observation', detail['inner_steps'][-1]['content']))
    ss['searcher_placeholder'].markdown(ss['node_info_temp'])
    if state == AgentStatusCode.END:
        ss[node_info_key].append(('answer', ss['node_info_temp']))
        ss['node_info_temp'] = ''


def display_chat_history():
    """Re-render the graph, planner output and node details of the last query.

    Only the most recent query is replayed. Output goes into the
    placeholders that ``update_chat`` stored in ``st.session_state``.
    Nothing is drawn if that query produced no graph.
    """
    last = len(st.session_state['queries']) - 1
    # Nothing asked yet, or the last query did not finish streaming, so the
    # per-query lists were never extended for it.
    if last < 0 or last >= len(st.session_state['graphs_html']):
        return
    graph_html = st.session_state['graphs_html'][last]
    if not graph_html:
        return
    nodes = st.session_state['nodes_list'][last]
    with st.session_state['expander_placeholder'].expander('Show Graph',
                                                           expanded=False):
        st.session_state['graph_placeholder']._html(graph_html, height=500)
    with st.session_state['container_placeholder'].container():
        col1, col2 = st.session_state['columns_placeholder'].columns([2, 1])
        with col1:
            answers = st.session_state['responses'][last]
            st.session_state['planner_placeholder'].markdown(
                answers[-1] if answers else '')
        with col2:
            used_keys = st.session_state['already_used_keys']
            if not used_keys:
                # No searcher node was ever selected, so there is no
                # stored selection to restore.
                return
            selected_node_key = used_keys[-1]
            node_names = list(nodes.keys())
            current = st.session_state.get(selected_node_key)
            st.session_state['selectbox_placeholder'] = st.empty()
            selected_node = st.session_state['selectbox_placeholder'].selectbox(
                'Select a node:',
                node_names,
                key=f'replay_key_{last}',
                index=(node_names.index(current)
                       if current in node_names else 0))
            st.session_state[selected_node_key] = selected_node
            if selected_node in ('root', 'response') or \
                    selected_node not in nodes:
                return
            # Replay the node's recorded transcript; a node that never
            # reached a transcript step has no entry yet.
            transcript = st.session_state.get(f'{selected_node}_info', [])
            for kind, text in transcript:
                if kind in ('thought', 'answer'):
                    st.session_state['searcher_placeholder'] = st.empty()
                    st.session_state['searcher_placeholder'].markdown(text)
                elif kind == 'observation':
                    st.session_state['observation_expander'] = st.empty()
                    with st.session_state['observation_expander'].expander(
                            'Results'):
                        st.write(text)


def clean_history():
    """Reset the chat so the next query starts from an empty session.

    Empties the per-query lists. It also drops the cached UI placeholders
    (keys ending in ``placeholder``), the per-node transcripts (``_info``)
    and the in-progress text buffers (``_temp``), so they are recreated on
    the next query.
    """
    for key in ('queries', 'responses', 'graphs_html', 'nodes_list',
                'adjacency_list_list', 'history', 'already_used_keys'):
        st.session_state[key] = []
    # Snapshot the keys first: deleting entries from a mapping while
    # iterating it directly is fragile.
    for key in list(st.session_state.keys()):
        if isinstance(key, str) and key.endswith(
                ('placeholder', '_info', '_temp')):
            del st.session_state[key]


# Main function to run the Streamlit app
def main():
    """Render the page: query input, clear button, live answer and replay."""
    st.sidebar.title('Model Control')
    query_col, button_col = st.columns([4, 1])
    with query_col:
        query = st.chat_input('Enter your query:')
    with button_col:
        if st.button('Clear History'):
            clean_history()
    # A fresh query is streamed first; the latest query is then replayed.
    if query:
        update_chat(query)
    display_chat_history()


if __name__ == '__main__':
    main()