import functools
import html
import warnings
from typing import Any, Dict, Optional

import numpy as np
import spacy
from PIL import ImageFont
from spacy.tokens import Doc


@functools.lru_cache(maxsize=32)
def _load_truetype_font(font_name, font_size):
    """Load a TrueType font once per (name, size) pair and reuse it afterwards."""
    return ImageFont.truetype(font_name, font_size)


def get_pil_text_size(text, font_size, font_name):
    """Return the (width, height) in pixels of ``text`` drawn with a TrueType font.

    text (str): Text to measure.
    font_size (int): Font size in points, passed to ``ImageFont.truetype``.
    font_name (str): Font file name or path, passed to ``ImageFont.truetype``.
    RETURNS (tuple): ``(width, height)``; callers in this file use only the width.
    """
    # Fonts are cached because this is called once per rendered token.
    font = _load_truetype_font(font_name, font_size)
    if hasattr(font, "getbbox"):
        # FreeTypeFont.getsize() was removed in Pillow 10. For the default
        # anchor, its result corresponded to the right/bottom edges of getbbox().
        _, _, right, bottom = font.getbbox(text)
        return right, bottom
    # Very old Pillow releases without getbbox().
    return font.getsize(text)


def render_arrow(
        label: str, start: int, end: int, direction: str, i: int, render_id: int = 0
) -> str:
    """Render an individual arrow (arc, label and arrowhead) as SVG markup.

    The arc is drawn with its endpoints at y=50, its Bézier control points at
    y=5, and both x-coordinates shifted right by 10px.

    label (str): Dependency label; inserted into the markup verbatim (not escaped).
    start (int): X-coordinate of the arc's start point (before the 10px shift).
    end (int): X-coordinate of the arc's end point (before the 10px shift).
    direction (str): "left" puts the arrowhead on the start point; any other
        value puts it on the end point. "rtl" additionally sets the label's
        textPath side to "right" (all other values use "left").
    i (int): Arrow index, used in the element id so that several arrows in one
        SVG get distinct ids (the label's textPath references the arc by id).
    render_id (int): ID of the enclosing rendering, also part of the element id.
        Defaults to 0, which matches the previous hard-coded value.
    RETURNS (str): Rendered SVG markup.
    """
    TPL_DEP_ARCS = """
    <g class="displacy-arrow">
        <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="red"/>
        <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
            <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="red" text-anchor="middle">{label}</textPath>
        </text>
        <path class="displacy-arrowhead" d="{head}" fill="red"/>
    </g>
    """
    x_start = start + 10
    x_end = end + 10
    arc = get_arc(x_start, 50, 5, x_end)
    arrowhead = get_arrowhead(direction, x_start, 50, x_end)
    label_side = "right" if direction == "rtl" else "left"
    return TPL_DEP_ARCS.format(
        id=render_id,
        # Previously hard-coded to 0, which gave every arrow the same id.
        i=i,
        stroke=2,
        head=arrowhead,
        label=label,
        label_side=label_side,
        arc=arc,
    )


def get_arc(x_start: int, y: int, y_curve: int, x_end: int) -> str:
    """Build the SVG path data for one dependency arc.

    The arc is a cubic Bézier curve from (x_start, y) to (x_end, y). Its two
    control points share the endpoints' x-coordinates and both sit at y_curve.

    x_start (int): X-coordinate of the arc's start point.
    y (int): Y-coordinate shared by the start and end points.
    y_curve (int): Y-coordinate of both Bézier control points.
    x_end (int): X-coordinate of the arc's end point.
    RETURNS (str): Value for the path's 'd' attribute.
    """
    move_to = f"M{x_start},{y}"
    curve_to = f"C{x_start},{y_curve} {x_end},{y_curve} {x_end},{y}"
    return f"{move_to} {curve_to}"


def get_arrowhead(direction: str, x: int, y: int, end: int) -> str:
    """Build the SVG path data for a triangular arrowhead.

    With direction "left" the tip sits on the arc's start point ``x``;
    with any other value it sits on the end point ``end``.

    direction (str): "left" for a head at the start point, otherwise at the end.
    x (int): X-coordinate of the arc's start point.
    y (int): Y-coordinate of the arc's start and end points.
    end (int): X-coordinate of the arc's end point.
    RETURNS (str): Value for the arrowhead path's 'd' attribute.
    """
    arrow_width = 6
    tip_y = y + 2
    base_y = y - arrow_width
    if direction != "left":
        tip, near, far = end, end + arrow_width - 2, end - arrow_width + 2
    else:
        tip, near, far = x, x - arrow_width + 2, x + arrow_width - 2
    return f"M{tip},{tip_y} L{near},{base_y} {far},{base_y}"


# parsed = [{'words': [{'text': 'The', 'tag': 'DET', 'lemma': None}, {'text': 'OnePlus', 'tag': 'PROPN', 'lemma': None}, {'text': '10', 'tag': 'NUM', 'lemma': None}, {'text': 'Pro', 'tag': 'PROPN', 'lemma': None}, {'text': 'is', 'tag': 'AUX', 'lemma': None}, {'text': 'the', 'tag': 'DET', 'lemma': None}, {'text': 'company', 'tag': 'NOUN', 'lemma': None}, {'text': "'s", 'tag': 'PART', 'lemma': None}, {'text': 'first', 'tag': 'ADJ', 'lemma': None}, {'text': 'flagship', 'tag': 'NOUN', 'lemma': None}, {'text': 'phone.', 'tag': 'NOUN', 'lemma': None}], 'arcs': [{'start': 0, 'end': 3, 'label': 'det', 'dir': 'left'}, {'start': 1, 'end': 3, 'label': 'nmod', 'dir': 'left'}, {'start': 1, 'end': 2, 'label': 'nummod', 'dir': 'right'}, {'start': 3, 'end': 4, 'label': 'nsubj', 'dir': 'left'}, {'start': 5, 'end': 6, 'label': 'det', 'dir': 'left'}, {'start': 6, 'end': 10, 'label': 'poss', 'dir': 'left'}, {'start': 6, 'end': 7, 'label': 'case', 'dir': 'right'}, {'start': 8, 'end': 10, 'label': 'amod', 'dir': 'left'}, {'start': 9, 'end': 10, 'label': 'compound', 'dir': 'left'}, {'start': 4, 'end': 10, 'label': 'attr', 'dir': 'right'}], 'settings': {'lang': 'en', 'direction': 'ltr'}}]
@functools.lru_cache(maxsize=4)
def _load_spacy_pipeline(model_name: str):
    """Load a spaCy pipeline once per model name and reuse it across calls."""
    return spacy.load(model_name)


def render_sentence_custom(unmatched_list: Dict, model_name: str = "en_core_web_lg") -> str:
    """Render a sentence as a displaCy-style SVG with one highlighted arc.

    unmatched_list (dict): Keys read here:
        "sentence" - text that is tokenized with spaCy and drawn word by word;
        "cur_word_index", "target_word_index" - token positions (in the spaCy
            tokenization of "sentence") of the arc's two endpoints;
        "dep" - label drawn on the arc.
    model_name (str): spaCy pipeline used for tokenization; loaded once and cached.
    RETURNS (str): The complete SVG markup.
    RAISES IndexError: If no token position lies between the two endpoint
        indices (e.g. both indices are past the end of the sentence).
    """
    TPL_DEP_WORDS = """
  <text class="displacy-token" fill="currentColor" text-anchor="start" y="{y}">
      <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
      <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
  </text>
  """

    TPL_DEP_SVG = """
  <svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
  """
    doc = _load_spacy_pipeline(model_name)(unmatched_list["sentence"])

    cur_index = unmatched_list["cur_word_index"]
    target_index = unmatched_list["target_word_index"]
    min_index, max_index = sorted((cur_index, target_index))
    # get_arrowhead puts the head on the start point for "left" and on the end
    # point otherwise, so in both cases the head lands on the current word.
    direction = "left" if cur_index < target_index else "rtl"

    x = 10
    svg_words = []
    arc_xs = []  # x-positions of the tokens inside [min_index, max_index]
    for index, token in enumerate(doc):
        word = token.text + " "
        # Measure the raw text; only the emitted markup needs XML escaping.
        word_width = get_pil_text_size(word, 16, "arial.ttf")[0]
        svg_words.append(
            TPL_DEP_WORDS.format(text=html.escape(word, quote=False), tag="", x=x, y=70)
        )
        if min_index <= index <= max_index:
            arc_xs.append(x)
            # Extra 50px gap after each word of the span except the last two;
            # presumably to leave room for the arc label.
            if index < max_index - 1:
                x += 50
        x += word_width + 4

    if not arc_xs:
        raise IndexError(
            f"No token of the sentence lies between indices {min_index} and {max_index}"
        )
    # Only one arrow is drawn, so its arrow index is 0.
    arc_svg = render_arrow(unmatched_list["dep"], arc_xs[0], arc_xs[-1], direction, 0)

    content = "".join(svg_words) + arc_svg
    return TPL_DEP_SVG.format(
        id=0,
        width=1200,
        height=75,
        color="#000000",  # was "#00000", an invalid 5-digit hex colour
        bg="#ffffff",
        font="Arial",
        content=content,
        dir="ltr",
        lang="en",
    )


def parse_deps(orig_doc: Doc, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Generate dependency parse in {'words': [], 'arcs': []} format.

    orig_doc (Doc): Document to parse. It is copied first (without user_data),
        so the optional retokenization below never modifies the caller's Doc.
    options (dict): Optional settings. Keys read here:
        "collapse_phrases" (default False) - merge each noun chunk into one token;
        "collapse_punct" (default True) - merge each non-punctuation token with
            the run of punctuation tokens that directly follows it;
        "fine_grained" - use ``tag_`` instead of ``pos_`` for each word's tag;
        "add_lemma" - include each word's lemma (otherwise None).
    RETURNS (dict): {"words": [...], "arcs": [...], "settings": {...}}. Each arc
        is {"start", "end", "label", "dir"} with start < end; "dir" is "left"
        when the dependent precedes its head. The root (its own head) gets no arc.
    """
    if options is None:
        options = {}
    doc = Doc(orig_doc.vocab).from_bytes(orig_doc.to_bytes(exclude=["user_data"]))
    if not doc.has_annotation("DEP"):
        warnings.warn(
            "parse_deps: the Doc has no dependency parse (DEP annotation); "
            "the generated arcs will not be meaningful.",
            stacklevel=2,
        )
    if options.get("collapse_phrases", False):
        with doc.retokenize() as retokenizer:
            # Named `chunk` (not `np`) to avoid shadowing the numpy import.
            for chunk in list(doc.noun_chunks):
                attrs = {
                    "tag": chunk.root.tag_,
                    "lemma": chunk.root.lemma_,
                    "ent_type": chunk.root.ent_type_,
                }
                retokenizer.merge(chunk, attrs=attrs)
    if options.get("collapse_punct", True):
        spans = []
        for word in doc[:-1]:
            if word.is_punct or not word.nbor(1).is_punct:
                continue
            start = word.i
            end = word.i + 1
            while end < len(doc) and doc[end].is_punct:
                end += 1
            span = doc[start:end]
            spans.append((span, word.tag_, word.lemma_, word.ent_type_))
        with doc.retokenize() as retokenizer:
            for span, tag, lemma, ent_type in spans:
                attrs = {"tag": tag, "lemma": lemma, "ent_type": ent_type}
                retokenizer.merge(span, attrs=attrs)
    fine_grained = options.get("fine_grained")
    add_lemma = options.get("add_lemma")
    words = [
        {
            "text": w.text,
            "tag": w.tag_ if fine_grained else w.pos_,
            "lemma": w.lemma_ if add_lemma else None,
        }
        for w in doc
    ]
    arcs = []
    for word in doc:
        if word.i < word.head.i:
            arcs.append(
                {"start": word.i, "end": word.head.i, "label": word.dep_, "dir": "left"}
            )
        elif word.i > word.head.i:
            arcs.append(
                {
                    "start": word.head.i,
                    "end": word.i,
                    "label": word.dep_,
                    "dir": "right",
                }
            )
    return {"words": words, "arcs": arcs, "settings": get_doc_settings(orig_doc)}


def get_doc_settings(doc: Doc) -> Dict[str, Any]:
    """Collect the language and writing direction of a Doc.

    doc (Doc): Document to read settings from.
    RETURNS (dict): {"lang": ``doc.lang_``, "direction": the vocab writing
        system's "direction" entry, or "ltr" when it has none}.
    """
    writing_system = doc.vocab.writing_system
    settings = {"lang": doc.lang_}
    settings["direction"] = writing_system.get("direction", "ltr")
    return settings