from tensorRT import inference
import re
from collections import Counter
from vncorenlp import VnCoreNLP
from nltk.tokenize import sent_tokenize
import torch
import datetime
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import json
from . import utils
import time
from summary import text_summary, get_summary_bert
from function.clean_text import normalize_text
from .summary_with_llm import summary_with_llama
from .translate import translate_text_multi_layer
from scipy.spatial import distance
import copy
from .sentence_embbeding import embbeded_zh, embbeded_en, embedded_bge


# from . import detect_time as dt

# Use the GPU when PyTorch can see one; `use_cuda` keeps the same check as a plain bool.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(use_cuda)

# NOTE(review): `annotator` is read by detect_postaging() and by segment(lang="vi"), but its
# construction below is commented out, so those paths fail unless it is defined elsewhere.
# annotator = VnCoreNLP('vncorenlp/VnCoreNLP-1.1.1.jar', port=9191, annotators="wseg,pos", max_heap_size='-Xmx2g')


def detect_postaging(text_in, nlp=None):
    """Return the distinct noun keywords that a VnCoreNLP-style tagger finds in ``text_in``.

    A token is kept when its ``posTag`` is "Np" or "Ny". A token tagged "N" is kept
    only when its ``form`` contains "_", i.e. a multi-syllable word as joined by
    the word segmenter. Underscores are turned back into spaces. Duplicates are
    removed while keeping first-occurrence order, so the result is deterministic.

    Args:
        text_in: raw text to annotate.
        nlp: object with an ``annotate(text)`` method that returns
            ``{"sentences": [[{"form": str, "posTag": str}, ...], ...]}``.
            Defaults to the module-level ``annotator``.

    Returns:
        list of keyword strings.
    """
    if nlp is None:
        # NOTE(review): the module-level ``annotator`` may not be defined (its construction
        # is commented out in this file); pass ``nlp`` explicitly in that case.
        nlp = annotator
    annotated = nlp.annotate(text_in)
    keywords = []
    for sentence in annotated["sentences"]:
        for token in sentence:
            tag = token["posTag"]
            if tag in ("Np", "Ny") or (tag == "N" and "_" in token["form"]):
                keywords.append(token["form"].replace("_", " "))
    # dict.fromkeys de-duplicates while preserving order (set() would not).
    return list(dict.fromkeys(keywords))

def clean_text(text_in):
    """Strip HTML tags, inline JS, links and control whitespace, then normalize the text.

    Steps, in order:
      * remove ``<...>`` tags;
      * replace text from "function" to the last "}" on the same line (inline JS);
      * replace links, first those preceded by "Nguồn" ("source"), then bare ones.
        A link runs from "http://" or "https://" up to the first ".htm"/".html",
        ".vn", ".net" or ".vgp". For https links, the match may also end at a
        second "//".
      * replace newline, tab and carriage-return characters with spaces.

    Each removed piece is replaced by a single space (tags by nothing).

    Args:
        text_in: raw text.

    Returns:
        the cleaned text passed through ``normalize_text``.
    """
    link_patterns = (
        # "Nguồn <url>" source attributions
        r'(Nguồn)\s*?(http:\/\/).*?(\.html?)',
        r'(Nguồn)\s*?(https:\/\/).*?(\/\/)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.html?)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.vn)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.net)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.vgp)',
        r'(Nguồn)\s*?(http:\/\/).*?(\.vgp)',
        # bare links
        r'(http:\/\/).*?(\.html?)',
        r'(https:\/\/).*?(\/\/)',
        r'(https:\/\/).*?(\.html?)',
        r'(https:\/\/).*?(\.vn)',
        r'(https:\/\/).*?(\.net)',
        r'(https:\/\/).*?(\.vgp)',
        r'(http:\/\/).*?(\.vgp)',
    )
    text = re.sub(r'<.*?>', '', text_in)
    text = re.sub(r'(function).*}', ' ', text)
    # ".html?" (not ".htm" followed by a separate ".html" pass): the lazy ".*?"
    # stops at the first ".htm", so a plain ".htm" pattern left a stray "l" behind.
    for pattern in link_patterns:
        text = re.sub(pattern, ' ', text)
    text = re.sub(r'[\n\t\r]', ' ', text)
    return normalize_text(text)


def data_cleaning(docs):
    """Strip HTML tags, inline JS, links and control whitespace from each doc's ``message``.

    Only docs that have a ``message`` key are kept. Their ``message`` is
    cleaned in place, and the same dict objects are returned in input order.
    The cleaning steps match ``clean_text`` except that ``normalize_text`` is
    not applied.

    Args:
        docs: iterable of dicts.

    Returns:
        list of the docs that had a ``message``.
    """
    link_patterns = (
        # "Nguồn <url>" source attributions
        r'(Nguồn)\s*?(http:\/\/).*?(\.html?)',
        r'(Nguồn)\s*?(https:\/\/).*?(\/\/)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.html?)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.vn)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.net)',
        r'(Nguồn)\s*?(https:\/\/).*?(\.vgp)',
        r'(Nguồn)\s*?(http:\/\/).*?(\.vgp)',
        # bare links
        r'(http:\/\/).*?(\.html?)',
        r'(https:\/\/).*?(\/\/)',
        r'(https:\/\/).*?(\.html?)',
        r'(https:\/\/).*?(\.vn)',
        r'(https:\/\/).*?(\.net)',
        r'(https:\/\/).*?(\.vgp)',
        r'(http:\/\/).*?(\.vgp)',
    )
    res = []
    for d in docs:
        if 'message' not in d:
            continue
        # HTML tags, then inline JS ("function" .. last "}" on the same line).
        text = re.sub(r'<.*?>', '', d['message'])
        text = re.sub(r'(function).*}', ' ', text)
        # ".html?" so ".html" links are removed whole (a lazy ".*?(\.htm)" left a stray "l").
        for pattern in link_patterns:
            text = re.sub(pattern, ' ', text)
        d['message'] = re.sub(r'[\n\t\r]', ' ', text)
        res.append(d)
    return res


def segment(docs, lang="vi"):
    """Filter docs and, for Vietnamese, word-segment each doc's snippet.

    Docs are skipped when their ``message`` is longer than 8000 characters or
    when they have no ``snippet`` key.

    For ``lang == "vi"`` the snippet is tokenized with the module-level
    ``annotator``. The sentences are joined with spaces (the result has a
    leading space), non-breaking spaces are removed, and the result is stored
    in ``d['segmented_snippet']``; docs are modified in place. Other languages
    pass through unchanged.

    A doc whose segmentation raises is dropped, and a warning is printed so
    the failure is visible (e.g. when ``annotator`` is not defined).

    Returns:
        list of the kept docs, in input order.
    """
    segmented_docs = []
    for d in docs:
        if len(d.get('message', "")) > 8000 or 'snippet' not in d:
            continue
        if lang == "vi":
            try:
                sentences = annotator.tokenize(d.get('snippet', ""))
                segmented_snippet = ''.join(' ' + ' '.join(sentence) for sentence in sentences)
                d['segmented_snippet'] = segmented_snippet.replace('\xa0', '')
            except Exception as err:
                # Best-effort: drop this doc, but do not hide why.
                print(f"[WARNING] segment: dropping doc {d.get('id', '')!r}: {err!r}")
                continue
        segmented_docs.append(d)
    return segmented_docs


def timestamp_to_date(timestamp):
    """Format a POSIX timestamp as ``dd/mm/YYYY`` in the machine's local time zone."""
    local_moment = datetime.datetime.fromtimestamp(timestamp)
    return f"{local_moment:%d/%m/%Y}"


def re_ranking(result_topic, vectors_prompt, sorted_field):
    """Rank clusters by how close their first doc's vector is to any prompt vector.

    For each cluster, the score is the smallest ``scipy.spatial.distance.cosine``
    between the vector stored on the cluster's first doc and each vector in
    ``vectors_prompt``. Every doc of the cluster receives this value as
    ``similarity_score``, and the first doc's ``"vector"`` entry is deleted.
    ``result_topic`` is modified in place.

    Args:
        result_topic: dict cluster id -> list of doc dicts. The first doc must
            carry ``"vector"``, and also ``"max_score"`` when ``sorted_field``
            is non-blank.
        vectors_prompt: iterable of prompt embedding vectors.
        sorted_field: when blank, ``lst_top`` holds cluster sizes; otherwise it
            holds the first doc's ``"max_score"``.

    Returns:
        ``(idx, lst_ids, lst_top)``. ``idx`` orders positions of ``lst_ids``
        from smallest to largest distance. On any error, ``([], [], [])`` is
        returned so the caller rebuilds its ranking from scratch rather than
        appending to half-filled lists.
    """
    lst_score = []
    lst_ids = []
    lst_top = []
    try:
        for k, cluster in result_topic.items():
            lst_ids.append(k)
            if not sorted_field.strip():
                lst_top.append(len(cluster))
            else:
                lst_top.append(cluster[0]['max_score'])
            vector_center = np.array(cluster[0]["vector"])
            # Sentinel larger than any cosine distance (which lies in [0, 2]).
            best_distance = 11.0
            for vec in vectors_prompt:
                score = distance.cosine(np.array(vec), vector_center)
                if score < best_distance:
                    best_distance = score
            lst_score.append(best_distance)
            for d in cluster:
                d["similarity_score"] = best_distance
            del cluster[0]["vector"]
        idx = np.argsort(np.array(lst_score))
    except Exception as err:
        print(f"[WARNING] re_ranking failed, using default ranking: {err!r}")
        return [], [], []
    return idx, lst_ids, lst_top

def _select_cluster_docs(docs, limit):
    """Return up to ``limit`` docs of one cluster, preferring docs with a complete title.

    A title is "complete" when it is non-empty and does not end with "...".
    Docs are scanned in order, and as soon as ``limit`` complete-title docs have
    been seen they are returned. Otherwise the complete-title docs come first,
    followed by the remaining docs in their original order, cut to ``limit``.
    """
    titled = []
    untitled = []
    for doc in docs:
        title = doc.get("title", "")
        if title and not title.endswith("..."):
            titled.append(doc)
        else:
            untitled.append(doc)
        if len(titled) == limit:
            return titled
    return (titled + untitled)[:limit]


def post_processing(response, top_cluster=5, top_sentence=5, topn_summary=5, sorted_field='', max_doc_per_cluster = 50, delete_message=True, prompt="", hash_str: str= "", vectors_prompt: "list | None" = None):
    """Rank clusters, pick their representative docs and attach a summary to each.

    Args:
        response: dict cluster id -> list of doc dicts. The first doc of each
            cluster must carry ``title_summarize`` (a list), and ``max_score``
            when ``sorted_field`` is non-blank.
        top_cluster: number of clusters to keep; -1 keeps all.
        top_sentence: docs to pick per cluster (complete titles first); -1
            keeps every doc of each cluster.
        topn_summary: forwarded to ``get_summary_bert`` as ``topn``.
        sorted_field: blank ranks clusters by size; otherwise they are ranked
            by the first doc's ``max_score`` (largest first).
        max_doc_per_cluster: hard cap on docs kept per cluster.
        delete_message: remove ``message`` from the returned docs.
        prompt: when non-empty, clusters are first ranked with ``re_ranking``
            against ``vectors_prompt``. If that yields no ranking, the
            size/score ranking is used.
        hash_str: file-name stem of the JSON log written to
            ``log_llm/topic_result_after_postprocessing/``.
        vectors_prompt: prompt embeddings for ``re_ranking``. None means none.

    Returns:
        dict "1".."n" (by rank) -> list of docs. The first doc of each list
        receives ``content_summary``, ``num_of_post`` and ``topic_keywords``.
    """
    print(f'[INFO] sorted_field: {sorted_field}')
    if vectors_prompt is None:
        vectors_prompt = []

    lst_ids = []
    lst_top = []
    idx = []
    if prompt:
        idx, lst_ids, lst_top = re_ranking(response, vectors_prompt, sorted_field)
        print("idx_prompt: ", idx)
    if len(prompt) == 0 or len(idx) == 0:
        # Start from empty lists: a failed re_ranking may have left partial
        # entries, and appending to them would duplicate clusters.
        lst_ids = []
        lst_top = []
        for cluster_id in response:
            lst_ids.append(cluster_id)
            if not sorted_field.strip():
                lst_top.append(len(response[cluster_id]))
            else:
                lst_top.append(response[cluster_id][0]['max_score'])
        idx = np.argsort(np.array(lst_top))[::-1]
        print("idx_not_prompt: ", idx)
    if top_cluster == -1:
        top_cluster = len(idx)

    lst_res = []
    for i in idx[:top_cluster]:
        cluster_docs = response[lst_ids[i]]
        # Resolve -1 per cluster so one cluster's size never limits the next ones.
        limit = len(cluster_docs) if top_sentence == -1 else top_sentence
        lst_res.append(_select_cluster_docs(cluster_docs, limit))

    dict_res = {}
    for rank, selected in enumerate(lst_res):
        key = str(rank + 1)
        docs = selected[:max_doc_per_cluster]
        dict_res[key] = docs
        if not docs:
            continue
        head = docs[0]
        for doc in docs[:3]:
            head["title_summarize"].append(doc.get("snippet", ""))
        summary_text = get_summary_bert(head.get("message", ""), head.get("lang", "vi"), topn=topn_summary,
                                        title=head.get("title", ""), snippet=head.get("snippet", ""))
        # Fall back to snippet, then title, when the generated summary is too short.
        if len(summary_text) < 10:
            summary_text = head.get("snippet", "")
            if len(summary_text) < 10:
                summary_text = head.get("title", "")
        head["content_summary"] = utils.remove_image_keyword(summary_text)
        head["num_of_post"] = len(selected)
        head["topic_keywords"] = []
        if delete_message:
            for doc in docs:
                doc.pop("message", None)

    # The log copy never contains message text or embedding vectors.
    dict_log_pos = {}
    for key, docs in dict_res.items():
        log_docs = copy.deepcopy(docs)
        for doc in log_docs:
            doc.pop("message", None)
            doc.pop("vector", None)
        dict_log_pos[key] = log_docs
    # Explicit UTF-8: ensure_ascii=False writes raw non-ASCII text.
    with open(f"log_llm/topic_result_after_postprocessing/{hash_str}.json", "w", encoding="utf-8") as f:
        json.dump(dict_log_pos, f, ensure_ascii=False)
    return dict_res


def get_lang(docs):
    """Return the most frequent ``lang`` value among ``docs``.

    Docs without a ``lang`` key count as "". Any code starting with "zh_" is
    collapsed to "zh". When several languages are equally frequent, the one
    seen last wins. An empty ``docs`` returns "" instead of raising.

    Args:
        docs: iterable of doc dicts.

    Returns:
        the majority language code.
    """
    counts = Counter(d.get("lang", "") for d in docs)
    if not counts:
        return ""
    # max() keeps the first maximum it meets; reversing the insertion order
    # makes the last-seen language win ties.
    lang, count = max(reversed(list(counts.items())), key=lambda item: item[1])
    if lang.startswith("zh_"):
        lang = "zh"
    print("lang: ", lang, count)
    return lang


# Docs from these domains are never placed in a cluster.
_EXCLUDED_DOMAINS = ("cungcau.vn", "baomoi.com", "news.skydoor.net")


def _embed_texts(texts, lang, fallback):
    """Embed ``texts`` with the encoder chosen for ``lang``.

    "zh" uses ``embbeded_zh``, "en" uses ``embbeded_en``, and any other
    language uses ``inference.encode``. With ``fallback=True``, an empty result
    from the zh/en embedder is retried with ``inference.encode`` and a warning
    is printed.
    """
    if lang == "zh":
        vectors = embbeded_zh(texts)
    elif lang == "en":
        vectors = embbeded_en(texts)
    else:
        return inference.encode(texts, lang=lang)
    if fallback and len(vectors) == 0:
        print(f"[WARNING] Embedded {lang}: {len(vectors)} / {len(texts)}")
        vectors = inference.encode(texts, lang=lang)
    return vectors


def _cosine_agglomerative(linkage, distance_threshold):
    """Build a cosine-distance AgglomerativeClustering on any scikit-learn version.

    scikit-learn 1.2 renamed ``affinity`` to ``metric`` and 1.4 removed
    ``affinity``. Older releases reject ``metric`` with a TypeError.
    """
    try:
        return AgglomerativeClustering(n_clusters=None, compute_full_tree=True, metric='cosine',
                                       linkage=linkage, distance_threshold=distance_threshold)
    except TypeError:
        return AgglomerativeClustering(n_clusters=None, compute_full_tree=True, affinity='cosine',
                                       linkage=linkage, distance_threshold=distance_threshold)


def topic_clustering(docs, distance_threshold, top_cluster=5, top_sentence=5, topn_summary=5, sorted_field='', max_doc_per_cluster=50, 
                        delete_message=True, prompt="", type_cluster:str = "single", hash_str: str= "", id_topic=""):
    """Cluster ``docs`` by embedding similarity and return the post-processed top clusters.

    Steps:
      1. Keep at most the first 30000 docs. Zero docs return ``{}``; one doc
         returns ``{"0": docs}`` unprocessed.
      2. Look up the relevance prompt for ``id_topic`` in the JSON file
         ``data/topic_name.txt`` and embed its "#####"-separated segments.
      3. Embed "title. snippet" of every doc with the encoder for the majority
         language (``get_lang``). Cluster with cosine agglomerative clustering,
         ``linkage=type_cluster``; "complete" linkage on "zh"/"en" forces the
         threshold to 0.4.
      4. Put every doc that has a non-blank ``message`` and is not from an
         excluded domain into its cluster. Docs are annotated IN PLACE:
         title, snippet and message are cleaned, summary fields are
         initialised, and the embedding is stored under "vector".
      5. Sort each cluster by message length (longest first), drop repeated
         domain+title docs, record ``num_docs``/``max_score`` on the first
         doc, write a JSON log, and hand off to ``post_processing``.

    NOTE(review): ``prompt`` is accepted but unused; the prompt comes from ``id_topic``.
    """
    docs = docs[:30000]
    # Trivial inputs first: language detection needs at least one doc.
    if len(docs) < 1:
        return {}
    if len(docs) == 1:
        return {"0": docs}

    with open("data/topic_name.txt", encoding="utf-8") as f:
        dict_topic_name = json.load(f)
    topic_name_relevant = dict_topic_name.get(id_topic, "")
    lang = get_lang(docs)
    if type_cluster == "complete" and lang in ("zh", "en"):
        distance_threshold = 0.4

    t1 = time.time()
    vec_prompt = []
    if topic_name_relevant:
        # NOTE(review): segments shaped "<a>$$$$<b>" are embedded whole, including the
        # part before "$$$$"; confirm whether only "<b>" should be embedded.
        prompt_split = topic_name_relevant.split("#####")
        vec_prompt = _embed_texts(prompt_split, lang, fallback=False)

    features = [doc.get('title', "") + ". " + doc.get('snippet', "") for doc in docs]
    vectors = _embed_texts(features, lang, fallback=True)

    clusteror = _cosine_agglomerative(type_cluster, distance_threshold)
    clusteror.fit(vectors)
    print(f"Time encode + clustering: {time.time() - t1} {clusteror.n_clusters_}")

    result = {str(i + 1): [] for i in range(clusteror.n_clusters_)}
    cluster_score = {key: 0 for key in result}

    for i, (doc, label) in enumerate(zip(docs, clusteror.labels_)):
        if doc.get('domain', '') in _EXCLUDED_DOMAINS:
            continue
        if not doc.get('message', '').strip():
            continue
        key = str(label + 1)
        score = doc.get('score', 0)
        if score > cluster_score[key]:
            cluster_score[key] = score
        if 'title' in doc:
            doc['title'] = clean_text(doc['title'])
        if 'snippet' in doc:
            doc['snippet'] = clean_text(doc['snippet'])
        # Drop everything up to and including the first occurrence of the (already
        # cleaned) title, then of the snippet, so the body does not repeat them.
        message = doc['message']
        for prefix in (doc.get('title', ''), doc.get('snippet', '')):
            if prefix.strip():
                _, found, rest = message.partition(prefix)
                if found:
                    message = rest
        doc['message'] = clean_text(message)
        doc['title_summarize'] = []
        doc['content_summary'] = ""
        doc['total_facebook_viral'] = 0
        doc["vector"] = np.array(vectors[i]).tolist()
        result[key].append(doc)

    for key in list(result):
        cluster = sorted(result[key], key=lambda d: -len(d.get('message', '')))
        if not cluster:
            del result[key]
            continue
        if len(cluster) > 1:
            cluster = check_duplicate_title_domain(cluster)
        cluster[0]['num_docs'] = len(cluster)
        cluster[0]['max_score'] = cluster_score[key]
        result[key] = cluster

    # The log copy never contains message text or embedding vectors.
    dict_log = {}
    for key, cluster in result.items():
        log_docs = copy.deepcopy(cluster)
        for d in log_docs:
            d.pop("message", None)
            d.pop("vector", None)
        dict_log[key] = log_docs
    with open(f"log_llm/topic_result_before_postprocessing/{hash_str}.json", "w", encoding="utf-8") as f:
        json.dump(dict_log, f, ensure_ascii=False)
    return post_processing(result, top_cluster=top_cluster, top_sentence=top_sentence, topn_summary=topn_summary,
                           sorted_field=sorted_field, max_doc_per_cluster=max_doc_per_cluster,
                           delete_message=delete_message, prompt=topic_name_relevant, hash_str=hash_str,
                           vectors_prompt=vec_prompt)

def check_duplicate_title_domain(docs):
    """Drop docs whose domain+title pair already appeared earlier in ``docs``.

    The first occurrence of each ``"<domain> <title>"`` key is kept, and input
    order is preserved. Runs in O(n), using a set of seen keys.

    Args:
        docs: list of doc dicts.

    Returns:
        new list without the repeated docs.
    """
    seen = set()
    unique_docs = []
    for doc in docs:
        key = f"{doc.get('domain', '')} {doc.get('title', '')}"
        if key in seen:
            continue
        seen.add(key)
        unique_docs.append(doc)
    return unique_docs
def convert_date(text):
    """Return ``text`` with every "." and "-" replaced by "/" (e.g. "12.05-2020" -> "12/05/2020")."""
    separators_to_slash = str.maketrans(".-", "//")
    return text.translate(separators_to_slash)


def check_keyword(sentence):
    """Return True if ``sentence`` contains any of a fixed set of Vietnamese time-related words.

    Matching is plain substring containment, so a word also matches inside a
    longer word.
    """
    time_words = ('sáng', 'trưa', 'chiều', 'tối', 'đến', 'hôm', 'ngày', 'tới')
    return any(word in sentence for word in time_words)


def extract_events_and_time(docs, publish_date):
    """Collect sentences that mention a date together with a time-related keyword.

    Each doc's ``message`` is split into sentences with ``sent_tokenize``. A
    sentence is kept when ``check_keyword`` accepts it and it contains a date
    in one of these forms, tried in this order:
    dd/mm/yyyy, dd/mm (separator "-", " ", "/" or "."), "ngày dd tháng mm năm yyyy",
    "ngày dd tháng mm". Dates without a year take the third "/"-separated field
    of ``publish_date`` as the year.

    Args:
        docs: iterable of dicts with a ``message`` string.
        publish_date: string whose third "/"-separated field is the default
            year (presumably "dd/mm/yyyy" — confirm with caller).

    Returns:
        list of ``(sentence, date)`` tuples. ``date`` is "/"-joined with
        single-digit fields zero-padded, e.g. "12/05/2020".
    """
    def add_0(date_str):
        # Remove whitespace inside every "/"-separated field and left-pad short fields with "0".
        res = []
        for field in date_str.split('/'):
            field = re.sub(r'\s+', '', field)
            if len(field) < 2:
                field = '0' + field
            res.append(field)
        return '/'.join(res)

    def numeric_date(match_text, default_year=None):
        # Use only the digit runs of the match. This discards the boundary
        # characters captured by (\D|^)/(\D|$) and the words ngày/tháng/năm,
        # and accepts every separator the patterns allow.
        parts = re.findall(r'\d+', match_text)
        if default_year is not None:
            parts.append(default_year)
        return add_0('/'.join(parts))

    year = publish_date.split('/')[2]

    # (pattern, needs default year), in priority order.
    date_patterns = (
        # dd/mm/yyyy
        (re.compile(r'(\D|^)(?:0?[1-9]|[12][0-9]|3[01])[- \/.](?:0?[1-9]|1[012])[- \/.]([12]([0-9]){3})(\D|$)'), False),
        # dd/mm
        (re.compile(r'(\D|^)(?:0?[1-9]|[12][0-9]|3[01])[- \/.](?:0?[1-9]|1[012])(\D|$)'), True),
        # ngày dd tháng mm năm yyyy
        (re.compile(r'(ngày)\s*\d{1,2}\s*(tháng)\s*\d{1,2}\s*(năm)\s*\d{4}'), False),
        # ngày dd tháng mm
        (re.compile(r'(ngày)\s*\d{1,2}\s*(tháng)\s*\d{1,2}'), True),
    )

    result = []
    for doc in docs:
        for sentence in sent_tokenize(doc['message']):
            lower_sentence = sentence.lower()
            if not check_keyword(lower_sentence):
                continue
            for pattern, needs_year in date_patterns:
                match = pattern.search(lower_sentence)
                if match:
                    break
            else:
                continue  # no date in this sentence
            date_entity = numeric_date(match.group(), year if needs_year else None)
            result.append((sentence, date_entity))
    return result

def find_index_nearest_vector(cluster, vectors):
    """Return the index of the vector in ``vectors`` closest (by cosine distance) to ``cluster``'s centroid.

    Args:
        cluster: 2-D array-like of the cluster's member vectors; the centroid
            is their mean.
        vectors: 2-D array-like of candidate vectors.

    Returns:
        a length-1 integer array holding the index of the nearest candidate.
    """
    centroid = np.mean(cluster, axis=0, keepdims=True)
    # Cosine DISTANCE (1 - similarity), so the nearest candidate is the argmin.
    # Taking argmin of the similarity would select the farthest one.
    distances = distance.cdist(centroid, vectors, metric='cosine')
    return np.argmin(distances, axis=1)

def re_clustering(ids, vectors, distance_threshold, max_doc_per_cluster):
    """Return the vectors of the largest tight sub-cluster among ``vectors[ids]``.

    ``vectors[ids]`` is re-clustered with complete-linkage, cosine
    agglomerative clustering at a fixed distance threshold of 0.12. The members
    of the biggest sub-cluster are returned; on a tie in size, the lowest label
    wins. Fewer than two vectors are returned unchanged.

    Args:
        ids: index array or boolean mask into ``vectors``; presumably one
            cluster's members — confirm with caller.
        vectors: 2-D numpy array of embeddings.
        distance_threshold: NOTE(review) unused; the threshold is hard-coded to 0.12.
        max_doc_per_cluster: NOTE(review) unused.

    Side effects: appends the chosen sub-cluster's pairwise cosine-similarity
    matrix to a log file. On any error, it appends the error and the inputs to
    another log file and returns ``vectors[ids]`` unchanged. Both log paths are
    hard-coded absolute paths.
    """
    sub_vectors = vectors[ids]

    try:
        if sub_vectors.shape[0] < 2:
            return sub_vectors
        try:
            sub_clusteror = AgglomerativeClustering(n_clusters=None, compute_full_tree=True, metric='cosine',
                                                    linkage='complete', distance_threshold=0.12)
        except TypeError:
            # scikit-learn < 1.2 only accepts the old ``affinity`` keyword (removed in 1.4).
            sub_clusteror = AgglomerativeClustering(n_clusters=None, compute_full_tree=True, affinity='cosine',
                                                    linkage='complete', distance_threshold=0.12)
        sub_clusteror.fit(sub_vectors)
        labels = sub_clusteror.labels_
        # argmax returns the first (lowest) label among the largest sub-clusters.
        biggest = int(np.argmax(np.bincount(labels, minlength=sub_clusteror.n_clusters_)))
        members = sub_vectors[labels == biggest]
        cosine_scores = cosine_similarity(members, members)
        with open("/home/vietle/topic-clustering/log_score.txt", "a") as f:
            f.write(str(cosine_scores) + "\n")
        return members
    except Exception as e:
        with open("/home/vietle/topic-clustering/log_clustering_diemtin/log_cluster_second.txt", "a") as f:
            payload = {"ids": np.asarray(ids).tolist(), "vectors": np.asarray(vectors).tolist()}
            # One entry per line so consecutive errors stay separable.
            f.write(str(e) + "$$" + json.dumps(payload) + "\n")
        return sub_vectors