import json
import time
from .utils import get_sbert_embedding, clean_text
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
from nltk import sent_tokenize
import requests
# from clean_text import normalize_text

# Soft upper bound (in characters) on the text feature built by preocess_feature;
# it can be exceeded while the feature is still shorter than MIN_LENGTH_FEATURE.
MAX_LENGTH_FEATURE = 250
# Target minimum feature length (in characters): preocess_feature keeps appending
# sentences below it, and topic_clustering drops docs whose feature is not longer.
MIN_LENGTH_FEATURE = 100
# Endpoint of the spam classifier used by check_spam (hard-coded internal address;
# NOTE(review): consider making this configurable).
URL_CHECK_SPAM = "http://10.9.3.70:30036/predict"

def check_spam(docs, timeout=None):
    """Drop documents that the remote spam classifier flags.

    All ``message`` texts are sent to ``URL_CHECK_SPAM`` in one batch request.
    A document is kept only when the service's prediction for it has
    ``label == 0`` (presumably "not spam").

    Args:
        docs: iterable of dict documents; ``message`` (default "") is classified.
        timeout: optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, as before.

    Returns:
        List of the kept documents (the same dict objects, original order).

    Raises:
        requests.HTTPError: if the service answers with an error status.
        ValueError: if the number of predictions differs from the number of
            documents sent, since predictions are matched to docs by position.
    """
    docs = list(docs)
    if not docs:
        # Nothing to classify; skip the network round-trip entirely.
        return []

    json_body = {
        "domain_id": "",
        "records": [{"text": doc.get("message", ""), "idxcol": 1} for doc in docs],
    }
    response = requests.post(URL_CHECK_SPAM, json=json_body, timeout=timeout)
    # Fail loudly on HTTP errors instead of indexing into an error payload.
    response.raise_for_status()
    predictions = response.json()

    if len(predictions) != len(docs):
        raise ValueError(
            f"spam service returned {len(predictions)} predictions for {len(docs)} documents"
        )
    return [doc for doc, prediction in zip(docs, predictions) if prediction["label"] == 0]

def preocess_feature(doc):
    """Build a short text feature from a document's ``message`` field.

    The message is split into lines. Lines longer than 10 characters (after
    stripping) are cleaned with ``clean_text(..., normalize=False)`` and then
    appended whole while they fit under ``MAX_LENGTH_FEATURE``. When a
    paragraph does not fit and the feature is still shorter than
    ``MIN_LENGTH_FEATURE``, it is split into sentences. Each sentence is
    appended if it fits, or while the feature remains too short.

    Returns:
        The feature string; each appended piece is prefixed with a single
        space, so a non-empty result starts with " ".
    """
    candidates = [
        clean_text(stripped, normalize=False)
        for stripped in (line.strip() for line in doc.get("message", "").split("\n"))
        if len(stripped) > 10
    ]

    feature = ""
    for paragraph in candidates:
        if len(feature) + len(paragraph) < MAX_LENGTH_FEATURE:
            feature = " ".join((feature, paragraph))
            continue
        if len(feature) >= MIN_LENGTH_FEATURE:
            # Already long enough; skip paragraphs that would overflow.
            continue
        # Feature is too short and this paragraph is too long: take sentences.
        for sentence in sent_tokenize(paragraph):
            fits = len(feature) + len(sentence) < MAX_LENGTH_FEATURE
            too_short = len(feature.strip()) < MIN_LENGTH_FEATURE
            if fits or too_short:
                feature = " ".join((feature, sentence))
    return feature

# Domains whose documents are clustered (they still influence grouping) but
# are never included in the returned clusters.
_EXCLUDED_DOMAINS = frozenset({"cungcau.vn", "baomoi.com", "news.skydoor.net"})


def _build_clusterer(distance_threshold):
    """Create a complete-linkage, cosine-distance AgglomerativeClustering.

    scikit-learn 1.2 renamed ``affinity`` to ``metric`` (and 1.4 removed
    ``affinity``), so try the new spelling first and fall back for old versions.
    """
    params = dict(
        n_clusters=None,
        compute_full_tree=True,
        linkage='complete',
        distance_threshold=distance_threshold,
    )
    try:
        return AgglomerativeClustering(metric='cosine', **params)
    except TypeError:
        # scikit-learn < 1.2 does not know the ``metric`` keyword.
        return AgglomerativeClustering(affinity='cosine', **params)


def _select_features(docs):
    """Return shallow copies of docs with a long-enough feature stored in ``title``.

    Copies are used so the caller's document dicts are not mutated.
    """
    selected = []
    for doc in docs:
        feature = preocess_feature(doc)
        if len(feature) > MIN_LENGTH_FEATURE:
            selected.append(dict(doc, title=feature))
    return selected


def _clean_response_doc(doc):
    """Run ``clean_text`` over the display fields of ``doc`` in place and return it."""
    for field in ('title', 'snippet', 'message'):
        if field in doc:
            doc[field] = clean_text(doc[field])
    return doc


def topic_clustering(docs, distance_threshold, top_cluster=5, top_sentence=5, topn_summary=5, sorted_field='', max_doc_per_cluster=50, delete_message=True, is_check_spam = True):
    """Group documents into topics by semantic similarity of a short text feature.

    Pipeline:
      1. Keep docs whose ``message`` is longer than 100 characters, capped at
         the first 30000.
      2. Optionally drop spam via ``check_spam``.
      3. Build a feature per doc with ``preocess_feature``. Docs whose feature
         is not longer than ``MIN_LENGTH_FEATURE`` are dropped.
      4. Embed the features with ``get_sbert_embedding`` and cluster them with
         complete-linkage agglomerative clustering on cosine distance.

    Args:
        docs: list of dict documents. Reads ``message``, ``title``,
            ``snippet``, ``domain`` and ``score``.
        distance_threshold: cosine-distance cut-off for merging clusters.
        top_cluster: number of largest clusters to return.
        top_sentence, topn_summary, sorted_field, max_doc_per_cluster,
            delete_message: currently unused; kept for interface compatibility.
        is_check_spam: whether to filter docs through ``check_spam``.

    Returns:
        Dict mapping a 1-based cluster id (str) to its docs. Contents:
          - Docs in a cluster are sorted by descending ``message`` length. The
            first doc carries ``num_docs`` and ``max_score`` (the highest
            ``score`` in the cluster, floored at 0).
          - Clusters are ordered largest first; empty clusters are dropped.
          - Docs from ``_EXCLUDED_DOMAINS`` or with a blank message are omitted.
          - Returned docs are copies whose ``title`` is the extracted feature.
            Their ``title``/``snippet``/``message`` are passed through
            ``clean_text``.
        Legacy special cases:
          - If exactly one doc survives the spam filter, it is returned
            unprocessed as ``{"0": [doc]}``.
          - If no docs remain, ``{}`` is returned.
    """
    docs = [x for x in docs if len(x.get("message", "")) > 100]
    docs = docs[:30000]
    if is_check_spam and docs:
        docs = check_spam(docs)

    if not docs:
        return {}
    if len(docs) == 1:
        return {"0": docs}

    t1 = time.time()
    docs = _select_features(docs)
    # Feature filtering may leave too few docs to cluster; handle that here
    # instead of letting the embedder / clusterer fail on 0 or 1 samples.
    if not docs:
        return {}
    if len(docs) == 1:
        labels, n_clusters = [0], 1
    else:
        vectors = get_sbert_embedding([doc["title"] for doc in docs])
        clusterer = _build_clusterer(distance_threshold)
        clusterer.fit(vectors)
        labels, n_clusters = clusterer.labels_, clusterer.n_clusters_
    print(f"Time encode + clustering: {time.time() - t1} {n_clusters}")

    # Pre-create keys in numeric order so ties in the final size sort keep
    # ascending cluster-id order.
    clusters = {str(k + 1): [] for k in range(n_clusters)}
    best_score = dict.fromkeys(clusters, 0)
    for doc, label in zip(docs, labels):
        if doc.get('domain', '') in _EXCLUDED_DOMAINS:
            continue
        if not doc.get('message', '').strip():
            continue
        key = str(label + 1)
        score = doc.get('score', 0)
        if score > best_score[key]:
            best_score[key] = score
        clusters[key].append(_clean_response_doc(doc))

    result = {}
    for key, members in clusters.items():
        if not members:
            continue
        members.sort(key=lambda d: -len(d.get('message', '')))
        members[0]['num_docs'] = len(members)
        members[0]['max_score'] = best_score[key]
        result[key] = members

    ranked = sorted(result.items(), key=lambda item: -len(item[1]))
    return dict(ranked[:top_cluster])

if __name__ == '__main__':
    # Ad-hoc local run: cluster the first 10000 docs of a JSON dump and write
    # the clusters out. The hard-coded absolute paths are machine-specific.
    input_path = "/home2/vietle/news-cms/topic_summarization/data/news_cms.social.json"
    output_path = "/home2/vietle/news-cms/topic_summarization/cluster/news_cms.social.json"

    # Explicit UTF-8: the data is non-ASCII, and json.dump(ensure_ascii=False)
    # would otherwise depend on the platform's default encoding.
    with open(input_path, 'r', encoding='utf-8') as f:
        docs = json.load(f)[:10000]

    clusters = topic_clustering(
        docs,
        distance_threshold=0.2,
        top_cluster=5000,
        top_sentence=5,
        topn_summary=5,
        sorted_field='',
        max_doc_per_cluster=50,
        delete_message=False,
    )

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(clusters, f, ensure_ascii=False)