from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
import spacy
from spacy import displacy

import streamlit as st

# Node fill colour per spaCy named-entity label, used when a graph node's text
# matches a recognised entity. The palette appears to mirror displaCy's
# defaults. spaCy's English pipelines label facilities as "FAC"; "FACILITY" is
# kept as well so any lookup by the older name keeps working.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FAC": "#9cc9cc",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}


@st.experimental_singleton(max_entries=1)
def get_pipeline():
    """Return the REBEL triplet-extraction pipeline.

    Builds a Hugging Face ``text2text-generation`` pipeline that uses
    ``Babelscape/rebel-large`` as both model and tokenizer. The result is
    cached through ``st.experimental_singleton`` (one entry).
    """
    model_id = 'Babelscape/rebel-large'
    return pipeline('text2text-generation', model=model_id, tokenizer=model_id)



@st.experimental_singleton(max_entries=1)
def load_spacy():
    """Return the ``en_core_web_sm`` spaCy pipeline, cached via ``st.experimental_singleton``."""
    return spacy.load("en_core_web_sm")


def generate_knowledge_graph(texts: List[str], filename: str):
    """Build a directed knowledge graph from ``texts`` and render it with pyvis.

    Relation triplets are extracted from each text via
    ``generate_partial_graph``. Every distinct lower-cased head/tail becomes a
    node, and every distinct (head, tail, relation) becomes a labelled edge.
    Nodes whose text matches a spaCy named entity are titled with the entity
    label and coloured from ``DEFAULT_LABEL_COLORS`` (grey ``#666666`` for
    labels without a colour).

    Args:
        texts: Input passages; each one is sent to the triplet extractor
            separately and all of them are joined for NER.
        filename: Output name passed to ``Network.show``.

    Returns:
        The lower-cased node names, in first-seen order (heads before tails).
    """
    nlp = load_spacy()

    # Run NER on the original casing (capitalisation is a strong cue for the
    # model) and lower-case the entity text afterwards so it matches the
    # lower-cased node names. The first label seen for an entity text wins.
    doc = nlp("\n".join(texts))
    entity_labels = {}
    for ent in doc.ents:
        entity_labels.setdefault(ent.text.lower(), ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))

    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # node list (and thus the rendered graph) is stable across runs; set()
    # order depends on string hash randomisation.
    heads = [t["head"].lower() for t in triplets]
    tails = [t["tail"].lower() for t in triplets]
    nodes = list(dict.fromkeys(heads + tails))

    net = Network(directed=True, width="700px", height="700px")

    for node in nodes:
        label = entity_labels.get(node)
        if label is None:
            net.add_node(node, shape="circle")
        else:
            color = DEFAULT_LABEL_COLORS.get(label, "#666666")
            net.add_node(node, title=label, shape="circle", color=color)

    # Key edges on a tuple of the normalised fields. Plain string
    # concatenation could collide ("ab" + "c" == "a" + "bc"), and because it
    # did not lower-case head/tail it let duplicate edges through between the
    # same lower-cased nodes.
    seen_edges = set()
    for triplet in triplets:
        head = triplet["head"].lower()
        tail = triplet["tail"].lower()
        key = (head, tail, triplet["type"].lower())
        if key in seen_edges:
            continue
        seen_edges.add(key)
        net.add_edge(head, tail, title=triplet["type"], label=triplet["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes


@lru_cache(maxsize=16)
def generate_partial_graph(text: str):
    """Extract relation triplets from a single ``text`` with the REBEL pipeline.

    The pipeline is asked for generated token ids rather than text. The first
    result's ``output_ids`` are decoded with the pipeline's tokenizer, and the
    first decoded string is parsed by ``extract_triplets``. Results are
    memoised per ``text`` (up to 16 entries).
    """
    extractor = get_pipeline()
    generated = extractor(text, return_tensors=True, return_text=False)
    token_ids = generated[0]["generated_token_ids"]["output_ids"]
    decoded = extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(decoded[0])


def extract_triplets(text):
    """Parse REBEL's linearised output into a list of triplet dicts.

    Expected shape: ``<triplet> HEAD <subj> TAIL <obj> RELATION``, where one
    HEAD may be followed by several ``<subj> TAIL <obj> RELATION`` groups.
    The ``<s>``, ``<pad>`` and ``</s>`` markers are dropped, and tokens seen
    before the first marker are ignored.

    Returns:
        A list of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts whose
        values are whitespace-stripped.
    """
    found = []
    head = tail = rel = ''
    mode = 'x'

    def emit():
        # Reads the current head/tail/rel bindings at call time.
        found.append({'head': head.strip(), 'type': rel.strip(), 'tail': tail.strip()})

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(marker, "")

    for tok in cleaned.split():
        if tok == "<triplet>":
            mode = 't'
            if rel:
                emit()
                rel = ''
            head = ''
        elif tok == "<subj>":
            mode = 's'
            if rel:
                emit()
            tail = ''
        elif tok == "<obj>":
            mode = 'o'
            rel = ''
        elif mode == 't':
            head += ' ' + tok
        elif mode == 's':
            tail += ' ' + tok
        elif mode == 'o':
            rel += ' ' + tok

    # Flush the last triplet only if all three parts were collected.
    if head and rel and tail:
        emit()
    return found