import cv2
from .plot import *
from abcli import file
from abcli import path
from abcli import string
from abcli.plugins import graphics
from abcli.tasks import host
from abcli.tasks import objects
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import *
import time
from abcli.logging import crash_report
import abcli.logging
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Initial value of params["window_size"] on a new Image_Classifier;
# predict_frame() resizes frames to window_size x window_size.
default_window_size = 28


class Image_Classifier(object):
    def __init__(self):
        """Create an empty classifier: no model, no class names, default params."""
        self.model = None
        self.class_names = []

        # params travel with the model: save() writes them to params.json
        # and load() reads them back.
        self.params = dict(
            convnet=False,
            object_name="",
            model_size="",
            window_size=default_window_size,
        )

    def load(self, model_path):
        """Load class names, params, and the Keras model stored under model_path.

        Reads class_names.json and params.json from
        {model_path}/image_classifier/model and loads the model saved in that
        same folder. params["object_name"] and params["model_size"] are
        recomputed from model_path instead of being taken from the file.

        The classifier is only modified once every piece has loaded, so a
        failed load leaves the previous class_names / params / model untouched.

        Args:
            model_path: folder that contains image_classifier/model.

        Returns:
            True on success, False if a json file or the model failed to load.
        """
        model_folder = f"{model_path}/image_classifier/model"

        success, class_names = file.load_json(f"{model_folder}/class_names.json")
        if not success:
            return False

        success, params = file.load_json(f"{model_folder}/params.json")
        if not success:
            return False

        params["object_name"] = path.name(model_path)
        params["model_size"] = file.size(model_folder)

        try:
            model = tf.keras.models.load_model(model_folder)
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate.
            crash_report("image_classifier.load({}) failed".format(model_path))
            return False

        # Everything loaded: commit the new state in one go.
        self.class_names = class_names
        self.params = params
        self.model = model

        logger.info(
            "{}.load({}x{}:{}): {}{} class(es): {}".format(
                self.__class__.__name__,
                self.params["window_size"],
                self.params["window_size"],
                path.name(model_path),
                "convnet - " if self.params["convnet"] else "",
                len(self.class_names),
                ",".join(self.class_names),
            )
        )
        self.model.summary()

        return True

    def predict(self, test_images, test_labels, output_path="", page_count=-1):
        """Run the model on test_images and, optionally, write a report.

        Without output_path only the predictions are computed and the
        per-frame time is logged. With output_path the following are written
        under {output_path}/image_classifier:
          - predictions.pyndarray
          - model/confusion_matrix.pyndarray and .jpg   (only with test_labels)
          - model/label_distribution.jpg                (only with test_labels)
          - prediction/NNNNN.jpg sheets of 24 images each, via render()

        Args:
            test_images: array of images; its first axis indexes the images.
            test_labels: integer class index per image, or None to skip the
                label-based outputs.
                NOTE(review): the opening log line passes test_labels to
                string.pretty_shape_of_matrix even when it is None - confirm
                that helper tolerates None.
            output_path: folder to write the report to; "" disables the report.
            page_count: maximum number of sheets to render; -1 renders all.

        Returns:
            True on success, False if saving or rendering a report item failed.
            The return value of render() is not checked.
        """
        # Images per rendered sheet; render() draws a 4 x 6 grid.
        sheet_size = 24

        logger.info(
            "image_classifier.predict({},{}){}".format(
                string.pretty_shape_of_matrix(test_images),
                string.pretty_shape_of_matrix(test_labels),
                "-> {}".format(output_path) if output_path else "",
            )
        )

        image_count = test_images.shape[0]

        prediction_time = time.time()
        predictions = self.model.predict(test_images)
        # Per-frame time; max() keeps an empty batch from dividing by zero.
        prediction_time = (time.time() - prediction_time) / max(image_count, 1)
        logger.info(
            "image_classifier.predict(): {} / frame".format(
                string.pretty_duration(
                    prediction_time,
                    include_ms=True,
                )
            )
        )

        if not output_path:
            return True

        if not file.save(
            f"{output_path}/image_classifier/predictions.pyndarray", predictions
        ):
            return False

        if test_labels is not None:
            from sklearn.metrics import confusion_matrix

            class_count = len(self.class_names)

            logger.info("image_classifier.predict(): rendering confusion_matrix...")

            cm = confusion_matrix(
                test_labels,
                np.argmax(predictions, axis=1),
                labels=range(class_count),
            )
            # Normalize each row (true class) to sum to 1. A class with no
            # test samples has an all-zero row; dividing it by 1 keeps it at
            # zero instead of turning it into NaN.
            row_totals = np.sum(cm, axis=1)[:, np.newaxis]
            row_totals[row_totals == 0] = 1
            cm = cm / row_totals
            logger.debug("confusion_matrix: {}".format(cm))

            if not file.save(
                f"{output_path}/image_classifier/model/confusion_matrix.pyndarray", cm
            ):
                return False

            if not graphics.render_confusion_matrix(
                cm,
                self.class_names,
                f"{output_path}/image_classifier/model/confusion_matrix.jpg",
                header=[
                    " | ".join(host.signature()),
                    " | ".join(objects.signature()),
                ],
                footer=self.signature(prediction_time),
            ):
                return False

            logger.info(
                "image_classifier.predict(): rendering test_labels distribution..."
            )

            # minlength guarantees one bin per class even when the highest
            # class indices do not occur in test_labels.
            distribution = np.bincount(test_labels, minlength=class_count)
            distribution = distribution / max(np.sum(distribution), 1)

            if not graphics.render_distribution(
                distribution,
                self.class_names,
                f"{output_path}/image_classifier/model/label_distribution.jpg",
                header=[
                    " | ".join(host.signature()),
                    " | ".join(objects.signature()),
                ],
                footer=self.signature(prediction_time),
                title="distribution of test_labels",
            ):
                return False

        max_index = image_count
        if page_count != -1:
            max_index = max(0, min(sheet_size * page_count, max_index))
        # Ceiling division: a partially filled last sheet still counts.
        sheet_count = (max_index + sheet_size - 1) // sheet_size
        logger.info(
            f"image_classifier.predict(): rendering {sheet_count} sheet(s)..."
        )
        for index in tqdm(range(0, max_index, sheet_size)):
            page = slice(index, index + sheet_size)
            self.render(
                predictions[page],
                None if test_labels is None else test_labels[page],
                test_images[page],
                "{}/image_classifier/prediction/{:05d}.jpg".format(
                    output_path,
                    index // sheet_size,
                ),
                prediction_time,
            )

        return True

    def predict_frame(self, frame):
        """Classify a single frame.

        The frame is resized to window_size x window_size (from params),
        divided by 255.0, given a leading axis of length 1, and passed to the
        model. The per-class scores and the elapsed time are logged.

        Args:
            frame: image accepted by cv2.resize.

        Returns:
            (True, index of the highest-scoring class) on success,
            (False, -1) if resizing or prediction raised.
        """
        prediction_time = time.time()
        try:
            window_size = self.params["window_size"]
            resized = cv2.resize(frame, (window_size, window_size))
            prediction = self.model.predict(
                np.expand_dims(resized / 255.0, axis=0)
            )
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate.
            crash_report("image_classifier.predict_frame() crashed.")
            return False, -1

        prediction_time = time.time() - prediction_time

        output = np.argmax(prediction)

        scores = ",".join(
            "{}:{:.2f}".format(class_name, value)
            for class_name, value in zip(self.class_names, prediction[0])
        )
        logger.info(
            "image_classifier.prediction: [{}] -> {} - took {}".format(
                scores,
                self.class_names[output],
                string.pretty_duration(
                    prediction_time,
                    include_ms=True,
                    short=True,
                ),
            )
        )

        return True, output

    def render(
        self,
        predictions,
        test_labels,
        test_images,
        output_filename="",
        prediction_time=0,
    ):
        """Draw up to 24 predictions on one sheet and optionally save it.

        Each prediction takes two side-by-side subplots in a 4 x 6 grid: one
        drawn by plot_image and one by plot_value_array.

        With output_filename, the figure is written to the filename returned
        by file.auxiliary("prediction", "png"), closed, re-loaded as an image,
        stamped with graphics.add_signature, and saved to output_filename.
        Without it, the figure is left open (nothing here closes it).

        Args:
            predictions: per-image prediction vectors for this sheet.
            test_labels: labels forwarded to the plot helpers (may be None).
            test_images: images forwarded to plot_image.
            output_filename: destination image file; "" skips saving.
            prediction_time: per-frame time shown in the signature footer.

        Returns:
            None. Failures to load or save the image are not reported.
        """
        num_rows = 4
        num_cols = 6
        shown = min(num_rows * num_cols, len(predictions))

        plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
        for i in range(shown):
            plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
            plot_image(i, predictions[i], test_labels, test_images, self.class_names)
            plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
            plot_value_array(i, predictions[i], test_labels)

        # Lay the figure out once, after every subplot exists, instead of
        # re-running the layout pass after each image.
        if shown:
            plt.tight_layout()

        if not output_filename:
            return

        filename_ = file.auxiliary("prediction", "png")
        plt.savefig(filename_)
        plt.close()

        success, image = file.load_image(filename_)
        if success:
            image = graphics.add_signature(
                image,
                [" | ".join(host.signature()), " | ".join(objects.signature())],
                self.signature(prediction_time),
            )
            file.save_image(output_filename, image)

    def save(self, model_path):
        """Save the model, class names, and params under model_path.

        The model is saved to {model_path}/image_classifier/model;
        class_names.json and params.json are written into that same folder.
        params["object_name"] and params["model_size"] are refreshed from the
        saved model before params.json is written.

        Args:
            model_path: destination folder.

        Returns:
            True on success, False if preparing the folder, saving the model,
            or writing either json file failed.
        """
        model_filename = f"{model_path}/image_classifier/model"

        if not file.prepare_for_saving(model_filename):
            return False

        try:
            self.model.save(model_filename)
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate.
            crash_report(f"-image_classifier: save({model_path}): failed.")
            return False
        logger.info(f"image_classifier.model -> {model_filename}")

        self.params["object_name"] = path.name(model_path)

        # Measured after the model has been written.
        self.params["model_size"] = file.size(model_filename)

        if not file.save_json(
            f"{model_filename}/class_names.json",
            self.class_names,
        ):
            return False

        if not file.save_json(
            f"{model_filename}/params.json",
            self.params,
        ):
            return False

        return True

    def signature(self, prediction_time):
        """Describe this classifier in a single " | "-separated line.

        The line lists, in order: the literal "image_classifier", the object
        name, the model size (blank when params["model_size"] is falsy), the
        input shape, the shortened class names joined by "/", and the
        per-frame prediction time.

        Args:
            prediction_time: duration passed to string.pretty_duration.

        Returns:
            A list holding that one line.
        """
        object_name = self.params["object_name"]

        size_text = ""
        if self.params["model_size"]:
            size_text = string.pretty_bytes(self.params["model_size"])

        shape_text = string.pretty_shape(self.input_shape)
        classes_text = "/".join(string.shorten(self.class_names))
        duration_text = string.pretty_duration(
            prediction_time,
            include_ms=True,
            largest=True,
            short=True,
        )

        fields = [
            "image_classifier",
            object_name,
            size_text,
            shape_text,
            classes_text,
            "took {} / frame".format(duration_text),
        ]
        return [" | ".join(fields)]

    @staticmethod
    def train(data_path, model_path, color=False, convnet=True, epochs=10):
        """Train a classifier on the dataset in data_path and save it.

        Loads train/test images and labels (*.pyndarray) and class_names.json
        from data_path, builds either a small convnet or a single-hidden-layer
        dense network, fits it, evaluates it on the test set, and writes the
        model, evaluation.json, and a prediction report under model_path.

        Args:
            data_path: folder holding train_images, train_labels, test_images,
                test_labels (.pyndarray) and class_names.json.
            model_path: folder the model and report are written to.
            color: True if the images have 3 channels.
            convnet: True for the convolutional model, False for the dense one.
            epochs: number of training epochs.

        Returns:
            True on success, False if loading the data, saving the model or
            evaluation, or producing the report failed.
        """
        classifier = Image_Classifier()
        classifier.params["convnet"] = convnet

        logger.info(
            "image_classifier.train({}) -{}> {}".format(
                data_path,
                "convnet-" if classifier.params["convnet"] else "",
                model_path,
            )
        )

        # Load the dataset, stopping at the first file that fails.
        success, train_images = file.load(f"{data_path}/train_images.pyndarray")
        if not success:
            return False
        success, train_labels = file.load(f"{data_path}/train_labels.pyndarray")
        if not success:
            return False
        success, test_images = file.load(f"{data_path}/test_images.pyndarray")
        if not success:
            return False
        success, test_labels = file.load(f"{data_path}/test_labels.pyndarray")
        if not success:
            return False
        success, classifier.class_names = file.load_json(
            f"{data_path}/class_names.json"
        )
        if not success:
            return False

        from tensorflow.keras.utils import to_categorical

        class_count = len(classifier.class_names)

        # Fix the one-hot width to the number of classes so that it matches
        # the model's output layer even when a split lacks the highest class.
        train_labels = to_categorical(train_labels, num_classes=class_count)
        test_labels = to_categorical(test_labels, num_classes=class_count)

        window_size = train_images.shape[1]
        # Record the size the model is trained on; predict_frame() resizes
        # incoming frames to params["window_size"].
        classifier.params["window_size"] = int(window_size)

        if color:
            input_shape = (window_size, window_size, 3)
        elif convnet:
            input_shape = (window_size, window_size, 1)
        else:
            input_shape = (window_size, window_size)
        logger.info(f"input:{string.pretty_shape(input_shape)}")

        if convnet and not color:
            # Conv2D needs an explicit channel axis.
            train_images = np.expand_dims(train_images, axis=3)
            test_images = np.expand_dims(test_images, axis=3)

        for name, thing in zip(
            "train_images,train_labels,test_images,test_labels".split(","),
            [train_images, train_labels, test_images, test_labels],
        ):
            logger.info("{}: {}".format(name, string.pretty_shape_of_matrix(thing)))
        logger.info(
            f"{len(classifier.class_names)} class(es): {', '.join(classifier.class_names)}"
        )

        train_images = train_images / 255.0
        test_images = test_images / 255.0

        if convnet:
            # https://medium.com/swlh/convolutional-neural-networks-for-multiclass-image-classification-a-beginners-guide-to-6dbc09fabbd
            classifier.model = tf.keras.Sequential(
                [
                    tf.keras.layers.Conv2D(
                        filters=48,
                        kernel_size=3,
                        activation="relu",
                        input_shape=input_shape,
                    ),
                    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
                    tf.keras.layers.Conv2D(
                        filters=48, kernel_size=3, activation="relu"
                    ),
                    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
                    tf.keras.layers.Conv2D(
                        filters=32, kernel_size=3, activation="relu"
                    ),
                    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
                    tf.keras.layers.Flatten(),
                    tf.keras.layers.Dense(128, activation="relu"),
                    tf.keras.layers.Dense(64, activation="relu"),
                    tf.keras.layers.Dense(class_count),
                    tf.keras.layers.Activation("softmax"),
                ]
            )
        else:
            # https://github.com/gato/tensor-on-pi/blob/master/Convolutional%20Neural%20Network%20digit%20predictor.ipynb
            classifier.model = tf.keras.Sequential(
                [
                    tf.keras.layers.Flatten(input_shape=input_shape),
                    tf.keras.layers.Dense(128, activation="relu"),
                    tf.keras.layers.Dense(class_count),
                    tf.keras.layers.Activation("softmax"),
                ]
            )

        classifier.model.summary()

        classifier.model.compile(
            optimizer="adam",
            loss=tf.keras.losses.categorical_crossentropy,
            metrics=["accuracy"],
        )

        classifier.model.fit(train_images, train_labels, epochs=epochs)

        # evaluate() returns [loss, accuracy] given the metrics compiled above.
        test_accuracy = float(
            classifier.model.evaluate(test_images, test_labels, verbose=2)[1]
        )
        logger.info("test accuracy: {:.4f}".format(test_accuracy))

        # NOTE(review): evaluation.json is written into the model folder
        # before save() calls file.prepare_for_saving() on that folder -
        # confirm that call does not clear it.
        if not file.save_json(
            f"{model_path}/image_classifier/model/evaluation.json",
            {"metrics": {"test_accuracy": test_accuracy}},
        ):
            return False

        if not classifier.save(model_path):
            return False

        # predict() expects class indices, so undo the one-hot encoding.
        return classifier.predict(
            test_images,
            np.argmax(test_labels, axis=1),
            model_path,
            page_count=10,
        )

    @property
    def input_shape(self):
        return self.model.layers[0].input_shape[1:] if self.model.layers else []