# ###########################################################################
#
#  CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
#  (C) Cloudera, Inc. 2022
#  All rights reserved.
#
#  Applicable Open Source License: Apache 2.0
#
#  NOTE: Cloudera open source products are modular software products
#  made up of hundreds of individual components, each of which was
#  individually copyrighted.  Each Cloudera open source product is a
#  collective work under U.S. Copyright Law. Your license to use the
#  collective work is as provided in your written agreement with
#  Cloudera.  Used apart from the collective work, this file is
#  licensed for your use pursuant to the open source license
#  identified above.
#
#  This code is provided to you pursuant a written agreement with
#  (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
#  this code. If you do not have a written agreement with Cloudera nor
#  with an authorized and properly licensed third party, you do not
#  have any rights to access nor to use this code.
#
#  Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
#  contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
#  KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
#  WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
#  IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
#  FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
#  AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
#  ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
#  OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
#  CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
#  RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
#  BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
#  DATA.
#
# ###########################################################################

from typing import Iterable

import altair as alt
from captum.attr._utils.visualization import (
    VisualizationDataRecord,
    format_word_importances,
    _get_color,
)

try:
    from IPython.display import display, HTML

    HAS_IPYTHON = True
except ImportError:
    HAS_IPYTHON = False

def format_classname(classname):
    return f'<td>{classname}</td>'

def visualize_text(
    datarecords: Iterable[VisualizationDataRecord], legend: bool = True
) -> "HTML":  # In quotes because this type doesn't exist in standalone mode
    assert HAS_IPYTHON, (
        "IPython must be available to visualize text. "
        "Please run 'pip install ipython'."
    )

    dom = []
    dom.append(
        '<head><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous"></head>'
    )
    dom.append("""<table width:100; class="table">""")
    rows = [
        "<thead>"
        "<tr>"
        "<th scope='col'><span class='text-nowrap'>Predicted Label</span></th>"
        "<th scope='col'><span class='text-nowrap'>Attribution Score</span></th>"
        "<th scope='col'><span class='text-nowrap'>Feature Importance</span></th>"
        "</tr>"
        "</thead>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tbody>",
                    "<tr>",
                    format_classname(
                        f"{datarecord.pred_class.capitalize()}"
                    ),
                    format_classname(f"{round(datarecord.attr_score.item(), 2)}"),
                    format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    "<tr>",
                    "</tbody>",
                ]
            )
        )

    dom.append("".join(rows))
    dom.append("</table>")

    if legend:
        dom.append("<div class='row'>")
        dom.append("<div class='col-6'>")
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label}  '.format(
                    value=_get_color(value), label=label
                )
            )
        dom.append("</div>")
        dom.append("<div class='col-6'></div>")

        dom.append("</div>")

    html = HTML("".join(dom))
    display(html)

    return html


def build_altair_classification_plot(format_cls_result):
    """
    Builds Altair bar chart for classification results.

    Args:
        format_cls_result (List): Output from `format_classification_results()`
    """
    source = alt.pd.DataFrame(format_cls_result)

    color_scale = alt.Scale(
        domain=[record["type"] for record in format_cls_result],
        range=["#00A3AF", "#F96702"],
    )

    c = (
        alt.Chart(source)
        .mark_bar(size=50)
        .encode(
            x=alt.X(
                "percentage_start:Q", axis=alt.Axis(title="Style Distribution (%)")
            ),
            x2=alt.X2("percentage_end:Q"),
            color=alt.Color(
                "type:N",
                legend=alt.Legend(title="Attribute"),
                scale=color_scale,
            ),
        )
        .properties(height=150)
    )

    return c