import json
import os
import re
import builtins
import shutil
import uuid
from functools import wraps
import streamlit as st
import pandas as pd
from custom import *


# 聊天记录处理
def clear_folder(path):
    if not os.path.exists(path):
        return
    for file_name in os.listdir(path):
        file_path = os.path.join(path, file_name)
        try:
            shutil.rmtree(file_path)
        except Exception:
            pass


def set_chats_path():
    save_path = 'chat_history'
    if 'apikey' not in st.secrets:
        clear_folder('tem_files')
        save_path = 'tem_files/tem_chat' + str(uuid.uuid4())
    return save_path


# 重新open函数,路径不存在时自动创建
def create_path(func):
    @wraps(func)
    def wrapper(path, *args, **kwargs):
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))
        return func(path, *args, **kwargs)

    return wrapper


open = create_path(builtins.open)


def get_history_chats(path):
    try:
        os.makedirs(path)
    except FileExistsError:
        pass
    files = [f for f in os.listdir(f'./{path}') if f.endswith('.json')]
    files_with_time = [(f, os.stat(f'./{path}/' + f).st_ctime) for f in files]
    sorted_files = sorted(files_with_time, key=lambda x: x[1], reverse=True)
    chat_names = [os.path.splitext(f[0])[0] for f in sorted_files]
    if len(chat_names) == 0:
        chat_names.append('New Chat_' + str(uuid.uuid4()))
    return chat_names


def save_data(path: str, file_name: str, history: list, paras: dict, contexts: dict, **kwargs):
    with open(f"./{path}/{file_name}.json", 'w', encoding='utf-8') as f:
        json.dump({"history": history, "paras": paras, "contexts": contexts, **kwargs}, f)


def remove_data(path: str, file_name: str):
    try:
        os.remove(f"./{path}/{file_name}.json")
    except FileNotFoundError:
        pass


def load_data(path: str, file_name: str) -> dict:
    try:
        with open(f"./{path}/{file_name}.json", 'r', encoding='utf-8') as f:
            data = json.load(f)
            return data
    except FileNotFoundError:
        with open(f"./{path}/{file_name}.json", 'w', encoding='utf-8') as f:
            f.write(json.dumps(initial_content_all))
        return initial_content_all


def show_each_message(message, role, area=None):
    if area is None:
        area = [st.markdown] * 2
    if role == 'user':
        icon = user_svg
        name = user_name
        background_color = user_background_color
    else:
        icon = gpt_svg
        name = gpt_name
        background_color = gpt_background_color
    area[0](f"\n<div class='avatar'>{icon}<h2>{name}:</h2></div>", unsafe_allow_html=True)
    #area[1](f"""<div class='content-div' style='background-color: {background_color};'>\n\n{message}""",
    #        unsafe_allow_html=True)
    area[1](f"""<div class='content-div'>\n\n{message}</div>""", unsafe_allow_html=True)


def show_messages(messages: list):
    for each in messages:
        if (each["role"] == "user") or (each["role"] == "assistant"):
            show_each_message(each["content"], each["role"])
        if each["role"] == "assistant":
            st.write("---")


# 根据context_level提取history
def get_history_input(history, level):
    df_history = pd.DataFrame(history)
    df_system = df_history.query('role=="system"')
    df_input = df_history.query('role!="system"')
    df_input = df_input[-level * 2:]
    res = pd.concat([df_system, df_input], ignore_index=True).to_dict('records')
    return res


# 去除#号右边的空格
def remove_hashtag_right__space(text):
    res = re.sub(r"(#+)\s*", r"\1", text)
    return res


# 提取文本
def extract_chars(text, num):
    char_num = 0
    chars = ''
    for char in text:
        # 汉字算两个字符
        if '\u4e00' <= char <= '\u9fff':
            char_num += 2
        else:
            char_num += 1
        chars += char
        if char_num >= num:
            break
    return chars