import re
from vncorenlp import VnCoreNLP
from nltk.tokenize import sent_tokenize
import torch
from sentence_transformers import SentenceTransformer
import datetime
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import requests
import json
from . import utils
import time
from summary import text_summary, get_summary_bert
from function.clean_text import normalize_text
# from . import detect_time as dt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = SentenceTransformer('model/distiluse-base-multilingual-cased-v2').to(device)
model = SentenceTransformer('model/paraphrase-multilingual-MiniLM-L12-v2')

# model = SentenceTransformer('VoVanPhuc/sup-SimCSE-VietNamese-phobert-base').to(device)
# model.save('model/distiluse-base-multilingual-cased-v2')

use_cuda = torch.cuda.is_available()
print(f'[INFO] CUDA available: {use_cuda}')
if use_cuda:
    model_en = SentenceTransformer('model/paraphrase-mpnet-base-v2').to(device)
else:
    model_en = model
# model_en.save('model/paraphrase-mpnet-base-v2')
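# VnCoreNLP word segmenter / POS tagger, run as a local server on port 9191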
annotator = VnCoreNLP('vncorenlp/VnCoreNLP-1.1.1.jar', port=9191, annotators="wseg,pos", max_heap_size='-Xmx8g')


def detect_postaging(text_in):
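    """Extract candidate keywords from Vietnamese text using VnCoreNLP POS tags.

    Keeps proper nouns (Np, Ny) and multi-word common nouns (N containing "_"),
    returning them de-duplicated with underscores replaced by spaces.
    """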
    word_segmented_text = annotator.annotate(text_in)
    lst_k = []
    for se in word_segmented_text["sentences"]:
        for kw in se:
            if kw["posTag"] in ("Np", "Ny", "N"):
                if kw["posTag"] == "N" and "_" not in kw["form"]:
                    continue
                lst_k.append(kw["form"].replace("_", " "))
    return list(set(lst_k))

def clean_text(text_in):
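    """Strip HTML tags, inline JS, "Nguồn ..." source credits, bare URLs and
    escape characters from a text, then normalize it with normalize_text."""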
    # css and js
    doc = re.sub(r'<.*?>', '', text_in)
    doc = re.sub(r'(function).*}', ' ', doc)

    # links: "Nguồn ..." source credits first, then bare URLs
    link_patterns = [
        r'(Nguồn)\s*?(http:\/\/).*?(\.htm)',
        r'(Nguồn)\s*?(http:\/\/).*?(\.html)',
        r'(Nguồn)\s*?(https:\/\/).*?(\/\/)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.htm)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.html)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.vn)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.net)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.vgp)',
        r'(Nguồn)\s*?(http:\/\/).*?(\.vgp)',
        r'(http:\/\/).*?(\.htm)',
        r'(http:\/\/).*?(\.html)',
        r'(https:\/\/).*?(\/\/)',
        r'(https:\/\/).*?(\.htm)',
        r'(https:\/\/).*?(\.html)',
        r'(https:\/\/).*?(\.vn)',
        r'(https:\/\/).*?(\.net)',
        r'(https:\/\/).*?(\.vgp)',
        r'(http:\/\/).*?(\.vgp)',
    ]
    for pattern in link_patterns:
        doc = re.sub(pattern, ' ', doc)

    # escape sequences
    doc = re.sub(r'[\n\t\r]', ' ', doc)

    doc = normalize_text(doc)
    return doc


def data_cleaning(docs):
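    """Apply HTML/JS/link/escape-character cleaning (without normalization) to the
    'message' field of each document; documents without a 'message' are dropped."""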
    res = []
    for d in docs:
        if 'message' in d:
            # css and js
            # css and js
            doc = re.sub(r'<.*?>', '', d['message'])
            doc = re.sub(r'(function).*}', ' ', doc)

            # links: "Nguồn ..." source credits first, then bare URLs
            link_patterns = [
                r'(Nguồn)\s*?(http:\/\/).*?(\.htm)',
                r'(Nguồn)\s*?(http:\/\/).*?(\.html)',
                r'(Nguồn)\s*?(https:\/\/).*?(\/\/)',
                r'(Nguồn)\s*?(https:\/\/).*?(\.htm)',
                r'(Nguồn)\s*?(https:\/\/).*?(\.html)',
                r'(Nguồn)\s*?(https:\/\/).*?(\.vn)',
                r'(Nguồn)\s*?(https:\/\/).*?(\.net)',
                r'(Nguồn)\s*?(https:\/\/).*?(\.vgp)',
                r'(Nguồn)\s*?(http:\/\/).*?(\.vgp)',
                r'(http:\/\/).*?(\.htm)',
                r'(http:\/\/).*?(\.html)',
                r'(https:\/\/).*?(\/\/)',
                r'(https:\/\/).*?(\.htm)',
                r'(https:\/\/).*?(\.html)',
                r'(https:\/\/).*?(\.vn)',
                r'(https:\/\/).*?(\.net)',
                r'(https:\/\/).*?(\.vgp)',
                r'(http:\/\/).*?(\.vgp)',
            ]
            for pattern in link_patterns:
                doc = re.sub(pattern, ' ', doc)

            # escape sequences
            doc = re.sub(r'[\n\t\r]', ' ', doc)

            d['message'] = doc
            res.append(d)
    return res


def segment(docs, lang="vi"):
    segmented_docs = []
    for d in docs:
        # if len(d.get('message', "")) > 8000 or len(d.get('message', "")) < 100:
        if len(d.get('message', "")) > 8000:
            continue
        if 'snippet' not in d:
            continue
        try:
            if lang == "vi":
                snippet = d.get('snippet', "")
                segmented_snippet = ""
                segmented_sentences_snippet = annotator.tokenize(snippet)
                for sentence in segmented_sentences_snippet:
                    segmented_snippet += ' ' + ' '.join(sentence)
                segmented_snippet = segmented_snippet.replace('\xa0', '')
                d['segmented_snippet'] = segmented_snippet
            segmented_docs.append(d)
        except Exception:
            pass
    return segmented_docs


def timestamp_to_date(timestamp):
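    """Convert a Unix timestamp to a dd/mm/YYYY date string."""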
    return datetime.datetime.fromtimestamp(timestamp).strftime('%d/%m/%Y')


def post_processing(response, top_cluster=5, top_sentence=5, topn_summary=5, sorted_field='', max_doc_per_cluster = 50, delete_message=True):
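    """Rank clusters and build the final response.

    Clusters are sorted by size, or by their 'max_score' when sorted_field is
    given. For each of the top_cluster clusters, up to top_sentence documents
    are kept (documents with a complete, non-truncated title are preferred),
    the representative document is summarized with get_summary_bert (falling
    back to its snippet or title), and 'message' fields are removed from the
    output when delete_message is True.
    """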
    print(f'[INFO] sorted_field: {sorted_field}')
    MAX_DOC_PER_CLUSTER = max_doc_per_cluster

    lst_ids = []
    lst_top = []
    lst_res = []
    for i in response:
        lst_ids.append(i)

        if not sorted_field.strip():
            lst_top.append(len(response[i]))
        else:
            lst_top.append(response[i][0]['max_score'])
        
    idx = np.argsort(np.array(lst_top))[::-1] 
    if top_cluster == -1:
        top_cluster = len(idx)
    for i in idx[: top_cluster]:
        ik = lst_ids[i]
        if top_sentence == -1:
            top_sentence = len(response[ik])
        lst_check_title = []
        lst_check_not_title = []
        i_c_t = 0
        for resss in response[ik]:
            r_title = resss.get("title", "")
            if r_title and not r_title.endswith("..."):
                lst_check_title.append(resss)
                i_c_t += 1
            else:
                lst_check_not_title.append(resss)
            if i_c_t == top_sentence:
                break
        if i_c_t == top_sentence:
            lst_res.append(lst_check_title)
        else:
            lst_check_title.extend(lst_check_not_title)
            lst_res.append(lst_check_title[:top_sentence])
        #lst_res.append(response[ik][:top_sentence])
    dict_res = {}
    for i in range(len(lst_res)):
        dict_res[str(i + 1)] = lst_res[i][:MAX_DOC_PER_CLUSTER]
        for j in range(min(len(dict_res[str(i + 1)]), 3)):
            dict_res[str(i + 1)][0]["title_summarize"].append(dict_res[str(i + 1)][j].get("snippet", ""))
        summary_text = get_summary_bert(dict_res[str(i + 1)][0].get("message", ""), lang=dict_res[str(i + 1)][0].get("lang", "vi"), topn=topn_summary, title=dict_res[str(i + 1)][0].get("title", ""), snippet=dict_res[str(i + 1)][0].get("snippet", ""))
        if len(summary_text) < 10:
            summary_text = dict_res[str(i + 1)][0].get("snippet", "")
            if len(summary_text) < 10:
                summary_text = dict_res[str(i + 1)][0].get("title", "")
        dict_res[str(i + 1)][0]["content_summary"] = utils.remove_image_keyword(summary_text)
        key_phrases = []
        dict_res[str(i + 1)][0]["topic_keywords"] = key_phrases

        if delete_message:
            for j in range(len(dict_res[str(i + 1)])):
                if "message" in dict_res[str(i + 1)][j]:
                    del dict_res[str(i + 1)][j]["message"]
    return dict_res


def get_lang(docs):
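    """Return "vi" if Vietnamese documents are at least as numerous as the rest, else "en"."""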
    lang_vi = 0
    lang_en = 0
    for d in docs:
        if d.get("lang", "") == "vi":
            lang_vi += 1
        else:
            lang_en += 1
    if lang_vi >= lang_en:
        return "vi"
    return "en"


def topic_clustering(docs, distance_threshold, top_cluster=5, top_sentence=5, topn_summary=5, sorted_field='', max_doc_per_cluster=50, delete_message=True):
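    """Cluster documents into topics and return the post-processed result.

    The first 30000 documents are kept and segmented; at least two are required.
    Titles and snippets are encoded with a SentenceTransformer (Vietnamese or
    English model depending on the dominant language), then grouped with
    single-linkage AgglomerativeClustering using cosine distance and the given
    distance_threshold. Each cluster keeps cleaned copies of its documents
    (leading title/snippet occurrences are stripped from the message, a small
    domain blacklist and empty messages are skipped), records the highest
    document score as 'max_score', drops duplicate domain+title pairs, and is
    finally ranked and summarized by post_processing.
    """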
    global model, model_en
    docs = docs[:30000]
    lang = get_lang(docs)
    result = {}
    cluster_score = {}
    docs = segment(docs, lang=lang)
    if len(docs) < 2:
        return result
    if lang == "vi":
        features = [doc.get('title', "") + ". " + doc.get('snippet', "") for doc in docs]
        vectors = model.encode(features, show_progress_bar=False)
    else:
        features = [doc.get('title', "") + ". " + doc.get('snippet', "") for doc in docs]
        vectors = model_en.encode(features, show_progress_bar=False)
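    # note: newer scikit-learn releases rename `affinity` to `metric` for
    # AgglomerativeClustering; the call below assumes a version that still accepts `affinity`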
    clusteror = AgglomerativeClustering(n_clusters=None, compute_full_tree=True, affinity='cosine',
                                        linkage='single', distance_threshold=distance_threshold)
    clusteror.fit(vectors)
    print(clusteror.n_clusters_)
    for i in range(clusteror.n_clusters_):
        result[str(i + 1)] = []
        cluster_score[str(i + 1)] = 0
    for i in range(len(clusteror.labels_)):
        cluster_no = clusteror.labels_[i]   
        if docs[i].get('domain','') not in ["cungcau.vn","baomoi.com","news.skydoor.net"]:
            response_doc = {}
            response_doc = docs[i]
            score = response_doc.get('score', 0)
            if not docs[i].get('message','').strip():
                continue 
            if score > cluster_score[str(cluster_no + 1)]:
                cluster_score[str(cluster_no + 1)] = score
            if 'domain' in docs[i]:
                response_doc['domain'] = docs[i]['domain']
            if 'url' in docs[i]:
                response_doc['url'] = docs[i]['url']
            if 'title' in docs[i]:
                response_doc['title'] = clean_text(docs[i]['title'])
            if 'snippet' in docs[i]:
                response_doc['snippet'] = clean_text(docs[i]['snippet'])
            if 'created_time' in docs[i]:
                response_doc['created_time'] = docs[i]['created_time']
            if 'message' in docs[i]:
                title = docs[i].get('title','')
                snippet = docs[i].get('snippet','')
                message = docs[i].get('message','')
                if title.strip():
                    split_mess = message.split(title)
                    if len(split_mess) > 1:
                        message = title.join(split_mess[1:])
                if snippet.strip():
                    split_mess = message.split(snippet)
                    if len(split_mess) > 1:
                        message = snippet.join(split_mess[1:])

                response_doc['message'] = clean_text(message)
            if 'id' in docs[i]:
                response_doc['id'] = docs[i]['id']
            # response_doc['score'] = 0.0
            response_doc['title_summarize'] = []
            response_doc['content_summary'] = ""
            response_doc['total_facebook_viral'] = 0
            result[str(cluster_no + 1)].append(response_doc)
    
    empty_clus_ids = []
    for x in result:
        result[x] = sorted(result[x], key=lambda i: -len(i.get('message','')))
        if len( result[x]) > 0:
            if len(result[x]) > 1:
                result[x] = check_duplicate_title_domain(result[x])
            result[x][0]['num_docs'] = len(result[x])
            result[x][0]['max_score'] = cluster_score[x]
        else:
            empty_clus_ids.append(x)
    
    for x in empty_clus_ids:
        result.pop(x,None)
    # result = dict(sorted(result.items(), key=lambda i: -len(i[1])))[:top_cluster]
    return post_processing(result, top_cluster=top_cluster, top_sentence=top_sentence, topn_summary=topn_summary, sorted_field = sorted_field, max_doc_per_cluster=max_doc_per_cluster, delete_message=delete_message)

def check_duplicate_title_domain(docs):
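    """Drop documents whose (domain, title) pair duplicates an earlier document."""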
    lst_title_domain = [f"{d.get('domain', '')} {d.get('title', '')}" for d in docs]
    # compare every pair; later occurrences of an already-seen (domain, title) are marked
    for i in range(len(lst_title_domain) - 1):
        for j in range(i + 1, len(lst_title_domain)):
            if lst_title_domain[j] == lst_title_domain[i]:
                lst_title_domain[j] = 'dup'
    lst_filter_docs = [docs[i] for i, x in enumerate(lst_title_domain) if x != 'dup']
    return lst_filter_docs


def convert_date(text):
    text = text.replace(".", "/")
    text = text.replace("-", "/")
    return text


def check_keyword(sentence):
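    """Return True if the sentence contains a Vietnamese time-related keyword
    (sáng/morning, trưa/noon, chiều/afternoon, tối/evening, đến/to,
    hôm/day, ngày/date, tới/coming)."""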
    keyword = ['sáng', 'trưa', 'chiều', 'tối', 'đến', 'hôm', 'ngày', 'tới']
    for k in keyword:
        if k in sentence:
            return True
    return False


def extract_events_and_time(docs, publish_date):
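    """Extract (sentence, date) pairs from document messages.

    Sentences that contain a time keyword and match one of the date patterns
    (dd/mm/yyyy, dd/mm, "ngày dd tháng mm năm yyyy", "ngày dd tháng mm") are
    returned together with a normalized dd/mm/yyyy date; the year from
    publish_date is appended when the match has no year.
    """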
    def standardize(date_str):
        return date_str.replace('.', '/').replace('-', '/')

    def add_0(date_str):
        # zero-pad day/month parts, e.g. "5/3/2023" -> "05/03/2023"
        date_str = date_str.split('/')
        res = []
        for o in date_str:
            o = re.sub(r'\s+', '', o)
            if len(o) < 2:
                o = '0' + o
            res.append(o)
        date_str = '/'.join(res)
        return date_str

    def get_date_list(reg, sentence):
        find_object = re.finditer(reg, sentence)
        date_list = [x.group() for x in find_object]
        return date_list

    year = publish_date.split('/')[2]

    # dd/mm/yyyy
    reg_exp_1 = r'(\D|^)(?:0?[1-9]|[12][0-9]|3[01])[- \/.](?:0?[1-9]|1[012])[- \/.]([12]([0-9]){3})(\D|$)'
    # # mm/yyyy
    # reg_exp_5 = r'(\D|^)(?:0?[1-9]|1[012])[- \/.]([12]([0-9]){3})(\D|$)'
    # dd/mm
    reg_exp_2 = r'(\D|^)(?:0?[1-9]|[12][0-9]|3[01])[- \/.](?:0?[1-9]|1[012])(\D|$)'

    # ngày dd tháng mm năm yyyy
    reg_exp_3 = r'(ngày)\s*\d{1,2}\s*(tháng)\s*\d{1,2}\s*(năm)\s*\d{4}'
    # ngày dd tháng mm
    reg_exp_4 = r'(ngày)\s*\d{1,2}\s*(tháng)\s*\d{1,2}'

    result = []
    for doc in docs:
        text = doc['message']
        for sentence in sent_tokenize(text):
            lower_sentence = sentence.lower()
            c = re.search(reg_exp_3, lower_sentence)
            d = re.search(reg_exp_4, lower_sentence)
            # e = re.search(reg_exp_5, lower_sentence)
            a = re.search(reg_exp_1, sentence)
            b = re.search(reg_exp_2, sentence)
            #
            if (a or b or c or d) and check_keyword(lower_sentence):
                date_list = get_date_list(reg_exp_1, lower_sentence)
                date_entity = ''
                if date_list:
                    date_entity = add_0(standardize(date_list[0]))
                elif get_date_list(reg_exp_2, lower_sentence):
                    date_list = get_date_list(reg_exp_2, lower_sentence)
                    date_entity = add_0(standardize(date_list[0]) + '/' + year)
                elif get_date_list(reg_exp_3, lower_sentence):
                    date_list = get_date_list(reg_exp_3, lower_sentence)

                    date_entity = date_list[0].replace('ngày', '').replace('tháng', '').replace('năm', '').strip()
                    date_entity = re.sub(r'\s+', ' ', date_entity)
                    date_entity = date_entity.replace(' ', '/')
                    date_entity = add_0(date_entity)
                else:
                    date_list = get_date_list(reg_exp_4, lower_sentence)
                    if date_list != []:
                        date_entity = date_list[0].replace('ngày', '').replace('tháng', '').replace('năm', '').strip()
                        date_entity = re.sub(r'\s+', ' ', date_entity)
                        date_entity = date_entity.replace(' ', '/')
                        date_entity = date_entity + '/' + year
                        date_entity = add_0(date_entity)
                result.append((sentence, date_entity))
    return result
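

# Minimal usage sketch. Assumption: the field values below are illustrative only;
# real documents come from the upstream crawler/search pipeline and carry more
# metadata, and the distance_threshold of 0.5 is an arbitrary example value.
if __name__ == "__main__":
    sample_docs = [
        {
            "id": "1",
            "title": "Sample title",
            "snippet": "Sample snippet describing the article.",
            "message": "Sample title. Sample snippet describing the article. Body text of the article.",
            "lang": "vi",
            "domain": "example.vn",
            "score": 1.0,
        },
        {
            "id": "2",
            "title": "Another title",
            "snippet": "Another snippet.",
            "message": "Another title. Another snippet. More body text.",
            "lang": "vi",
            "domain": "example.vn",
            "score": 0.5,
        },
    ]
    clusters = topic_clustering(sample_docs, distance_threshold=0.5, top_cluster=5,
                                top_sentence=5, topn_summary=5)
    print(json.dumps(clusters, ensure_ascii=False, indent=2))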