from . import *
from abcli import file
from abcli import string
import cv2
import numpy as np
import os.path
import abcli.logging
import logging

logger = logging.getLogger(__name__)


def _preview(values, count=10):
    """Return the first `count` entries of `values` joined by "," plus a trailing "..." (for log lines)."""
    return ",".join([str(value) for value in values[:count]] + ["..."])


def eval(input_path, output_path):
    """Score image_classifier predictions against the saved test labels.

    Loads `{input_path}/test_labels.pyndarray` and
    `{input_path}/image_classifier/predictions.pyndarray`, reduces each
    prediction row to the index of its largest entry, and writes
    `{output_path}/evaluation_report.json` as `{"accuracy": value}`.
    `value` is the fraction of matching samples. It is None when either
    file fails to load or when the two arrays disagree in length.

    Args:
        input_path: directory holding the labels and the predictions.
        output_path: directory the evaluation report is written to.

    Returns:
        The result of `file.save_json(...)` for the report.
    """
    from sklearn.metrics import accuracy_score

    report = {"accuracy": None}

    success, ground_truth = file.load(f"{input_path}/test_labels.pyndarray")
    if success:
        logger.info(
            "groundtruth: {} - {}".format(
                string.pretty_shape_of_matrix(ground_truth),
                _preview(ground_truth),
            )
        )
        success, predictions = file.load(
            f"{input_path}/image_classifier/predictions.pyndarray"
        )

    if success:
        # One row of scores per sample -> index of the highest score.
        predictions = np.argmax(predictions, axis=1).astype(np.uint8)
        logger.info(
            "predictions: {} - {}".format(
                string.pretty_shape_of_matrix(predictions),
                _preview(predictions),
            )
        )

        if len(predictions) != len(ground_truth):
            # accuracy_score would raise here; report "no accuracy" instead,
            # the same way a failed load is reported.
            logger.error(
                "image_classifier.eval({}->{}): {:,} prediction(s) != {:,} label(s).".format(
                    input_path, output_path, len(predictions), len(ground_truth)
                )
            )
        else:
            # sklearn's signature is accuracy_score(y_true, y_pred).
            report["accuracy"] = float(accuracy_score(ground_truth, predictions))

            logger.info(
                "image_classifier.eval({}->{}): {:.2f}%".format(
                    input_path, output_path, 100 * report["accuracy"]
                )
            )

    return file.save_json(os.path.join(output_path, "evaluation_report.json"), report)


def _list_of_frames(object_name):
    """Return the frames of `object_name` via the module-level `objects` namespace.

    Inside `preprocess`, the parameter called `objects` shadows the
    module-level `objects` name. A call to `objects.list_of_frames(...)`
    made there resolves against the argument (a str or a sequence of object
    names), which has no such attribute. This helper runs at module scope,
    where `objects` refers to the intended namespace.
    """
    # NOTE(review): `objects` is presumably provided by `from . import *`;
    # confirm it exposes `list_of_frames`.
    return objects.list_of_frames(object_name)


def preprocess(
    output_path,
    objects="",
    infer_annotation=True,
    purpose="predict",
    test_size=1.0 / 6,
    window_size=28,
):
    """Build image/label arrays for the image classifier and save them.

    Args:
        output_path: directory the .pyndarray files are written to. When
            `objects` is empty, frames are also read from this object, and
            labels come from its `annotations.json` (cast to uint8).
        objects: names of the objects whose frames are read. The position
            of each object in this sequence becomes the label of its frames.
        infer_annotation: when `objects` is given, whether to emit those
            per-object labels. When False, no labels are produced.
        purpose: "predict" saves test_images (and test_labels when labels
            exist). "train" splits images and labels into train/test sets
            and saves all four arrays.
        test_size: fraction of samples held out for the test set ("train").
        window_size: each frame is resized into a
            window_size x window_size x 3 uint8 slot.

    Returns:
        True on success. False on a frame/label count mismatch, missing
        labels for "train", a failed save, or an unrecognized `purpose`.
    """
    if objects:
        logger.info(
            "image_classifier.preprocess({}{})->{} - {}x{} - for {}".format(
                ",".join(objects),
                " + annotation" if infer_annotation else "",
                output_path,
                window_size,
                window_size,
                purpose,
            )
        )

        annotations = []
        list_of_images = []
        for index, object_name in enumerate(objects):
            list_of_images_ = [
                "{}/Data/{}/camera.jpg".format(object_name, frame)
                for frame in _list_of_frames(object_name)
            ]

            # Every frame of the index-th object is labelled `index`.
            annotations += len(list_of_images_) * [index]
            list_of_images += list_of_images_

        annotations = np.array(annotations) if infer_annotation else []
    else:
        logger.info(
            "image_classifier.preprocess({}) - {}x{} - for {}".format(
                output_path,
                window_size,
                window_size,
                purpose,
            )
        )

        list_of_images = [
            "{}/Data/{}/camera.jpg".format(output_path, frame)
            for frame in _list_of_frames(output_path)
        ]

        # No extra `.format()` pass: the f-string is already interpolated,
        # and a second pass breaks on paths that contain braces.
        annotations = file.load_json(
            f"{output_path}/annotations.json",
            civilized=True,
            default=None,
        )[1]
        # With default=None a missing/unreadable file presumably yields None.
        # Treat that as "no labels" rather than crashing in
        # np.array(None).astype(np.uint8).
        annotations = (
            [] if annotations is None else np.array(annotations).astype(np.uint8)
        )

    if len(annotations) and len(list_of_images) != len(annotations):
        # NOTE(review): `name` is not defined in this file; presumably it is
        # provided by `from . import *`.
        logger.error(
            f"-{name}: preprocess: mismatch between frame and annotation counts: {len(list_of_images):,} != {len(annotations):,}"
        )
        return False
    logger.info("{:,} frame(s)".format(len(list_of_images)))

    tensor = np.zeros(
        (len(list_of_images), window_size, window_size, 3),
        dtype=np.uint8,
    )

    error_count = 0
    for index, filename in enumerate(list_of_images):
        logger.info("+= {}".format(filename))
        success_, image = file.load_image(filename)
        if success_:
            try:
                tensor[index, :, :, :] = cv2.resize(image, (window_size, window_size))
            except Exception:
                # e.g. an image whose channels do not fit the 3-channel slot;
                # the row stays zero-filled and is counted as an error.
                from abcli.logging import crash_report

                crash_report("image_classifier.preprocess() failed")
                success_ = False

        if not success_:
            error_count += 1
    logger.info(
        "tensor: {}{}".format(
            string.pretty_shape_of_matrix(tensor),
            " {} error(s)".format(error_count) if error_count else "",
        )
    )

    success = False
    if purpose == "predict":
        if not file.save("{}/test_images.pyndarray".format(output_path), tensor):
            return False
        if len(annotations):
            if not file.save(
                "{}/test_labels.pyndarray".format(output_path), annotations
            ):
                return False
        success = True
    elif purpose == "train":
        if not len(annotations):
            logger.error(f"-{name}: preprocess: annotations are not provided.")
            return False

        from sklearn.model_selection import train_test_split

        (
            tensor_train,
            tensor_test,
            annotations_train,
            annotations_test,
        ) = train_test_split(tensor, annotations, test_size=test_size)
        logger.info(
            "test-train split: {:.0f}%-{:.0f}% ".format(
                len(annotations_test) / len(annotations) * 100,
                len(annotations_train) / len(annotations) * 100,
            )
        )
        logger.info(
            "tensor_train: {}".format(string.pretty_shape_of_matrix(tensor_train))
        )
        logger.info(
            "tensor_test: {}".format(string.pretty_shape_of_matrix(tensor_test))
        )
        logger.info(
            "annotations_train: {}".format(
                string.pretty_shape_of_matrix(annotations_train)
            )
        )
        logger.info(
            "annotations_test: {}".format(
                string.pretty_shape_of_matrix(annotations_test)
            )
        )

        success = (
            file.save("{}/train_images.pyndarray".format(output_path), tensor_train)
            and file.save("{}/test_images.pyndarray".format(output_path), tensor_test)
            and file.save(
                "{}/train_labels.pyndarray".format(output_path), annotations_train
            )
            and file.save(
                "{}/test_labels.pyndarray".format(output_path), annotations_test
            )
        )
    else:
        logger.error(f"-{name}: preprocess: {purpose}: purpose not found.")

    return success