|
import json
|
|
import os
|
|
|
|
|
|
|
|
from function.topic_clustering import model, AgglomerativeClustering
|
|
|
|
def check_duplicate_title_domain(docs):
    """Drop documents whose (domain, title) pair duplicates an earlier doc.

    Keeps the first occurrence of each "<domain> <title>" key and removes
    every later document carrying the same key; relative order of the
    surviving documents is preserved.

    Fixes the original off-by-one: the outer loop started at index 1, so
    duplicates of the very first document were never detected. Also replaces
    the O(n^2) pairwise scan with a single seen-set pass.

    Args:
        docs: list of dicts, each optionally carrying 'domain' and 'title'.

    Returns:
        New list of docs with domain+title duplicates removed.
    """
    seen = set()
    unique_docs = []
    for doc in docs:
        key = f"{doc.get('domain', '')} {doc.get('title', '')}"
        if key not in seen:
            seen.add(key)
            unique_docs.append(doc)
    return unique_docs
|
|
|
|
def main(req):
    """Merge pre-clustered per-date topics into deduplicated final clusters.

    Expects `req` to carry:
      - 'type': 'monthly' caps the output at 50 clusters, anything else at 20.
      - 'preprocess': list of {'topic': {topic_id: [doc, ...]}} entries,
        where each doc is a dict — assumes 'title'/'snippet' on the first doc
        of each topic and optional 'domain', 'created_time', 'num_docs' keys
        (TODO confirm against the producer of this payload).

    Source clusters are merged when the SBERT embeddings of their labels
    (first doc's title + snippet) fall within a cosine distance of 0.4; each
    merged cluster is sorted newest-first, deduplicated by domain+title and
    truncated. Merged clusters are ranked by total document count.

    Side effects:
        Dumps the result to 'zzz.json' in the working directory.

    Returns:
        dict mapping cluster rank (0-based int) -> list of doc dicts.
    """
    # Renamed from `type` to avoid shadowing the builtin.
    req_type = req['type']
    max_clusters = 50 if req_type == 'monthly' else 20
    max_docs_per_cluster = 50
    threshold = 0.4  # cosine-distance cut for merging cluster labels

    preprocess = req.get('preprocess', [])
    lst_labels = []
    lst_topics = []
    for date_clusters in preprocess:
        topic = date_clusters.get('topic', [])
        if topic:
            for topic_id in topic:
                topic_docs = topic[topic_id]
                lst_topics.append(topic_docs)
                # A cluster is labeled by its first doc's title + snippet.
                label = '. '.join([topic_docs[0].get('title', ''),
                                   topic_docs[0].get('snippet', '')])
                lst_labels.append(label)

    final_clusters = []
    label_clusters = sbert_clustering(lst_labels,
                                      distance_threshold=threshold,
                                      return_ids=True)
    print(label_clusters)

    # sbert_clustering returns None for empty input, hence the truth test.
    if label_clusters:
        for id_label_cluster in label_clusters:
            merge_clusters = []
            num_docs = 0
            for topic_id in label_clusters[id_label_cluster]:
                topic = lst_topics[topic_id]
                # Each source cluster reports its size on its first doc;
                # default to 1 when the key is absent.
                num_docs += topic[0].get('num_docs', 1)
                merge_clusters.extend(topic)

            # Newest first, drop duplicate domain+title, cap the size.
            merge_clusters = sorted(merge_clusters,
                                    key=lambda x: -x.get('created_time', 0))
            merge_clusters = check_duplicate_title_domain(merge_clusters)
            merge_clusters = merge_clusters[:max_docs_per_cluster]
            for doc in merge_clusters:
                doc['num_docs'] = num_docs
            final_clusters.append(merge_clusters)

    # Rank merged clusters by total size and keep only the largest ones.
    final_clusters = sorted(final_clusters, key=lambda x: -x[0]['num_docs'])
    final_clusters = final_clusters[:max_clusters]

    final_result = {i: cluster for i, cluster in enumerate(final_clusters)}
    with open('zzz.json', 'w') as f:
        json.dump(final_result, f, ensure_ascii=False)
    return final_result
|
|
|
|
def get_sbert_embedding(lst_sentence):
    """Encode the given sentences with the shared SBERT model.

    Args:
        lst_sentence: list of sentence strings.

    Returns:
        Whatever `model.encode` produces for the batch (embeddings).
    """
    return model.encode(lst_sentence)
|
|
|
|
def sbert_clustering(lst_sentence, distance_threshold=0.25, return_ids=False):
    """Group sentences by agglomerative clustering over SBERT embeddings.

    Underscores in the input are replaced by spaces before embedding.
    Clusters in the returned dict are ordered largest-first.

    Args:
        lst_sentence: sentences to cluster.
        distance_threshold: cosine-distance cut for complete linkage.
        return_ids: when True, cluster members are input indices instead of
            the sentence strings themselves.

    Returns:
        dict mapping cluster id -> list of sentences (or indices), ordered
        by descending cluster size; None when the input is empty.
    """
    lst_sentence = [sen.replace('_', ' ') for sen in lst_sentence]
    if not lst_sentence:
        # NOTE(review): callers only truth-test the result, so the original
        # None return (rather than {}) is kept for compatibility.
        return
    if len(lst_sentence) == 1:
        # Single sentence: trivially one cluster, no model call needed.
        return {0: [0]} if return_ids else {0: lst_sentence}

    embs = get_sbert_embedding(lst_sentence)

    # NOTE(review): `affinity` was renamed `metric` in scikit-learn 1.2 and
    # removed in 1.4 — confirm the pinned sklearn version before upgrading.
    hyer_clusteror = AgglomerativeClustering(
        n_clusters=None, compute_full_tree=True, affinity='cosine',
        linkage='complete', distance_threshold=distance_threshold)
    hyer_clusteror.fit(embs)

    # Single pass over labels_ instead of the original O(n_clusters * n)
    # nested scan (which also carried an always-true membership check).
    dict_result = {i: [] for i in range(hyer_clusteror.n_clusters_)}
    dict_ids = {i: [] for i in range(hyer_clusteror.n_clusters_)}
    for j, label in enumerate(hyer_clusteror.labels_):
        dict_result[label].append(lst_sentence[j])
        dict_ids[label].append(j)

    output = dict_ids if return_ids else dict_result
    # Largest clusters first.
    return dict(sorted(output.items(), key=lambda item: -len(item[1])))
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: load the request payload and run the merge pipeline.
    with open("input_merge.json", 'r') as fp:
        request = json.load(fp)
    main(request)