from .plot import *
from abcli import file
from abcli import string
# assumed abcli locations for the remaining helpers used below
from abcli import cache, graphics, host, objects, path
from abcli.options import Options
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
from tqdm import tqdm
import time
import abcli.logging
import logging

logger = logging.getLogger(__name__)


class Image_Classifier(object):
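    """A thin wrapper around a tf.keras image classification model.

    Holds the model, its class names, and its parameters, and knows how to
    load, save, train, and run predictions using the abcli object layout
    (class_names.json, params.json, image_classifier/model).
    """
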
    def __init__(self):
        self.class_names = []
        self.model = None
        self.params = {"convnet": False}

        self.object_name = ""
        self.model_size = ""

    def load(self, model_path):
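        """Load class names, params, window size, and the keras model from model_path.

        Returns True on success, False otherwise.
        """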
        success, self.class_names = file.load_json(f"{model_path}/class_names.json")
        if not success:
            return False

        success, self.params = file.load_json(f"{model_path}/params.json", default={})
        if not success:
            return False

        self.model_size = file.size(f"{model_path}/image_classifier/model")

        try:
            self.model = tf.keras.models.load_model(
                f"{model_path}/image_classifier/model"
            )
        except Exception:
            from abcli.logging import crash_report

            crash_report("image_classifier.load({}) failed".format(model_path))
            return False

        self.window_size = int(
            cache.read("{}.window_size".format(path.name(model_path)))
        )

        logger.info(
            "{}.load({}x{}:{}): {}{} class(es): {}".format(
                self.__class__.__name__,
                self.window_size,
                self.window_size,
                path.name(model_path),
                "convnet - " if self.params.get("convnet") else "",
                len(self.class_names),
                ",".join(self.class_names),
            )
        )
        self.model.summary()

        self.object_name = path.name(model_path)

        return True

    def predict(self, test_images, test_labels, output_path="", options=""):
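        """Run the model on test_images and, if output_path is given, save the raw
        predictions and render pages of annotated frames, a confusion matrix, and
        the test label distribution there (the latter two only if test_labels is
        not None).

        options: "cache" to also write the confusion matrix to the abcli cache,
        "page_count" to limit the number of 24-frame pages rendered.
        Returns True on success, False otherwise.
        """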
        options = Options(options).default("cache", False).default("page_count", -1)

        logger.info(
            "image_classifier.predict({},{}){}".format(
                string.pretty_size_of_matrix(test_images),
                string.pretty_size_of_matrix(test_labels),
                "-> {}".format(output_path) if output_path else "",
            )
        )

        prediction_time = time.time()
        predictions = self.model.predict(test_images)
        prediction_time = (time.time() - prediction_time) / test_images.shape[0]
        logger.info(
            "image_classifier.predict(): {} / frame".format(
                string.pretty_duration(prediction_time, include_ms=True)
            )
        )

        if not output_path:
            return True

        if not file.save("{}/predictions.pyndarray".format(output_path), predictions):
            return False

        if test_labels is not None:
            from sklearn.metrics import confusion_matrix

            logger.info("image_classifier.predict(): rendering confusion_matrix...")

            cm = confusion_matrix(
                test_labels,
                np.argmax(predictions, axis=1),
                labels=range(len(self.class_names)),
                # normalize="true",
            )
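            # row-normalize: each row (true class) sums to 1.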
            cm = cm / np.sum(cm, axis=1)[:, np.newaxis]
            logger.debug("confusion_matrix: {}".format(cm))

            if options["cache"]:
                if not cache.write("{}.confusion_matrix".format(self.object_name), cm):
                    return False

            if not file.save("{}/confusion_matrix.pyndarray".format(output_path), cm):
                return False

            if not graphics.render_confusion_matrix(
                cm,
                self.class_names,
                "{}/Data/0/info.jpg".format(output_path),
                {
                    "header": [
                        " | ".join(host.signature()),
                        " | ".join(objects.signature()),
                    ],
                    "footer": self.signature(prediction_time),
                },
            ):
                return False

        if test_labels is not None:
            logger.info(
                "image_classifier.predict(): rendering test_labels distribution..."
            )

            # known limitation: if test_labels never contains the largest class index,
            # np.bincount returns fewer bins than there are classes and the render
            # below fails, so predict() returns False.
            distribution = np.bincount(test_labels)
            distribution = distribution / np.sum(distribution)

            if not graphics.render_distribution(
                distribution,
                self.class_names,
                "{}/Data/1/info.jpg".format(output_path),
                {
                    "header": [
                        " | ".join(host.signature()),
                        " | ".join(objects.signature()),
                    ],
                    "footer": self.signature(prediction_time),
                    "title": "distribution of test_labels",
                },
            ):
                return False

        max_index = test_images.shape[0]
        if options["page_count"] != -1:
            max_index = min(24 * options["page_count"], max_index)
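        # continue frame numbering after any frames already present in output_path.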
        offset = int(np.max(np.array(objects.list_of_frames(output_path) + [-1]))) + 1
        logger.info(
            "image_classifier.predict(offset={}): rendering {} frame(s)...".format(
                offset, max_index
            )
        )
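        # render the predictions in pages of 24 frames (4x6 grids, see render()).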
        for index in tqdm(range(0, max_index, 24)):
            self.render(
                predictions[index : index + 24],
                None if test_labels is None else test_labels[index : index + 24],
                test_images[index : index + 24],
                "{}/Data/{}/info.jpg".format(output_path, int(index / 24) + offset),
                prediction_time,
            )

        return True

    def predict_frame(self, frame):
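        """Classify a single frame; returns (True, predicted class index) or
        (False, -1) on failure."""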
        prediction_time = time.time()
        try:
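            # resize to the model window, scale to [0, 1], and add a batch dimension.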
            prediction = self.model.predict(
                np.expand_dims(
                    cv2.resize(frame, (self.window_size, self.window_size)) / 255.0,
                    axis=0,
                )
            )
        except Exception:
            from abcli.logging import crash_report

            crash_report("image_classifier.predict_frame() crashed.")
            return False, -1

        prediction_time = time.time() - prediction_time

        output = np.argmax(prediction)

        logger.info(
            "image_classifier.prediction: [{}] -> {} - took {}".format(
                ",".join(
                    [
                        "{}:{:.2f}".format(class_name, value)
                        for class_name, value in zip(self.class_names, prediction[0])
                    ]
                ),
                self.class_names[output],
                string.pretty_duration(
                    prediction_time,
                    include_ms=True,
                    short=True,
                ),
            )
        )

        return True, output

    def render(
        self,
        predictions,
        test_labels,
        test_images,
        output_filename="",
        prediction_time=0,
    ):
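        """Render up to 24 (image, prediction) pairs as one page and, if
        output_filename is given, sign and save the page there."""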
        num_rows = 4
        num_cols = 6
        num_images = num_rows * num_cols
        plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
        for i in range(min(num_images, len(predictions))):
            plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
            plot_image(i, predictions[i], test_labels, test_images, self.class_names)
            plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
            plot_value_array(i, predictions[i], test_labels)
        plt.tight_layout()

        if output_filename:
            filename_ = file.auxiliary("prediction", "png")
            plt.savefig(filename_)
            plt.close()

            success, image = file.load_image(filename_)
            if success:
                image = graphics.add_signature(
                    image,
                    [" | ".join(host.signature()), " | ".join(objects.signature())],
                    self.signature(prediction_time),
                )
                file.save_image(output_filename, image)

    def save(self, model_path):
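        """Save the keras model, class names, and params under model_path.

        Returns True on success, False otherwise.
        """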
        model_filename = "{}/image_classifier/model".format(model_path)
        file.prepare_for_saving(model_filename)
        try:
            self.model.save(model_filename)
            logger.info("image_classifier.model -> {}".format(model_filename))
        except Exception:
            from abcli.logging import crash_report

            crash_report("image_classifier.save({}) failed".format(model_path))
            return False

        self.object_name = path.name(model_path)

        self.model_size = file.size("{}/image_classifier/model".format(model_path))

        if not file.save_json(
            "{}/class_names.json".format(model_path), self.class_names
        ):
            return False

        if not file.save_json("{}/params.json".format(model_path), self.params):
            return False

        return True

    def signature(self, prediction_time):
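        """Return a one-element list describing the model, used as a footer for
        rendered images."""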
        return [
            " | ".join(
                [
                    "image_classifier",
                    self.object_name,
                    string.pretty_bytes(self.model_size) if self.model_size else "",
                    string.pretty_size(self.input_shape),
                    "/".join(string.shorten(self.class_names)),
                    "took {} / frame".format(
                        string.pretty_duration(
                            prediction_time,
                            include_ms=True,
                            longest=True,
                            short=True,
                        )
                    ),
                ]
            )
        ]

    @staticmethod
    def train(data_path, model_path, options=""):
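        """Train a classifier on the *.pyndarray datasets in data_path and write the
        trained model and its evaluation to model_path.

        options: "color" for 3-channel input, "convnet" for the convolutional
        architecture, "epochs" for the number of training epochs.
        Returns True on success, False otherwise.
        """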
        options = (
            Options(options)
            .default("color", False)
            .default("convnet", True)
            .default("epochs", 10)
        )

        classifier = Image_Classifier()
        classifier.params["convnet"] = options["convnet"]

        logger.info(
            "image_classifier.train({}) -{}> {}".format(
                data_path,
                "convnet-" if classifier.params["convnet"] else "",
                model_path,
            )
        )

        success, train_images = file.load("{}/train_images.pyndarray".format(data_path))
        if success:
            success, train_labels = file.load(f"{data_path}/train_labels.pyndarray")
        if success:
            success, test_images = file.load(f"{data_path}/test_images.pyndarray")
        if success:
            success, test_labels = file.load(f"{data_path}/test_labels.pyndarray")
        if success:
            success, classifier.class_names = file.load_json(
                f"{data_path}/class_names.json"
            )
        if not success:
            return False

        from tensorflow.keras.utils import to_categorical

        train_labels = to_categorical(train_labels)
        test_labels = to_categorical(test_labels)

        window_size = train_images.shape[1]
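        # (H, W, 3) for color, (H, W, 1) for a grayscale convnet, (H, W) for the dense model.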
        input_shape = (
            (window_size, window_size, 3)
            if options["color"]
            else (window_size, window_size, 1)
            if options["convnet"]
            else (window_size, window_size)
        )
        logger.info(f"input_shape:{string.pretty_size(input_shape)}")

        if options["convnet"] and not options["color"]:
            train_images = np.expand_dims(train_images, axis=3)
            test_images = np.expand_dims(test_images, axis=3)

        for name, thing in zip(
            "train_images,train_labels,test_images,test_labels".split(","),
            [train_images, train_labels, test_images, test_labels],
        ):
            logger.info("{}: {}".format(name, string.pretty_size_of_matrix(thing)))
        logger.info(
            "{} class(es): {}".format(
                len(classifier.class_names), classifier.class_names
            )
        )

        train_images = train_images / 255.0
        test_images = test_images / 255.0

        if options["convnet"]:
            # https://medium.com/swlh/convolutional-neural-networks-for-multiclass-image-classification-a-beginners-guide-to-6dbc09fabbd
            classifier.model = tf.keras.Sequential(
                [
                    tf.keras.layers.Conv2D(
                        filters=48,
                        kernel_size=3,
                        activation="relu",
                        input_shape=input_shape,
                    ),
                    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
                    tf.keras.layers.Conv2D(
                        filters=48, kernel_size=3, activation="relu"
                    ),
                    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
                    tf.keras.layers.Conv2D(
                        filters=32, kernel_size=3, activation="relu"
                    ),
                    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
                    tf.keras.layers.Flatten(),
                    tf.keras.layers.Dense(128, activation="relu"),
                    tf.keras.layers.Dense(64, activation="relu"),
                    tf.keras.layers.Dense(len(classifier.class_names)),
                    tf.keras.layers.Activation("softmax"),
                ]
            )
        else:
            # https://github.com/gato/tensor-on-pi/blob/master/Convolutional%20Neural%20Network%20digit%20predictor.ipynb
            classifier.model = tf.keras.Sequential(
                [
                    tf.keras.layers.Flatten(input_shape=input_shape),
                    tf.keras.layers.Dense(128, activation="relu"),
                    tf.keras.layers.Dense(len(classifier.class_names)),
                    tf.keras.layers.Activation("softmax"),
                ]
            )

        classifier.model.summary()

        classifier.model.compile(
            optimizer="adam",
            loss=tf.keras.losses.categorical_crossentropy,
            metrics=["accuracy"],
        )

        classifier.model.fit(train_images, train_labels, epochs=options["epochs"])

        test_accuracy = float(
            classifier.model.evaluate(test_images, test_labels, verbose=2)[1]
        )
        logger.info("test accuracy: {:.4f}".format(test_accuracy))

        if not file.save_json(
            f"{model_path}/eval.json",
            {"metrics": {"test_accuracy": test_accuracy}},
        ):
            return False

        if not classifier.save(model_path):
            return False

        return classifier.predict(
            test_images,
            np.argmax(test_labels, axis=1),
            model_path,
            # predict() parses its options through Options(); pass them as a string.
            options="cache,page_count=10",
        )

    @property
    def input_shape(self):
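        """Input shape expected by the model, without the batch dimension."""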
        return self.model.layers[0].input_shape[1:] if self.model.layers else []