# python3.7
"""Contains utility functions for image processing.

The module is primarily built on `cv2`. But, differently, we assume all colorful
images are with `RGB` channel order by default. Also, we assume all gray-scale
images to be with shape [height, width, 1].
"""

import os
import cv2
import numpy as np

from .misc import IMAGE_EXTENSIONS
from .misc import check_file_ext

__all__ = [
    'get_blank_image', 'load_image', 'save_image', 'resize_image',
    'add_text_to_image', 'preprocess_image', 'postprocess_image',
    'parse_image_size', 'get_grid_shape', 'list_images_from_dir'
]


def _check_2d_image(image):
    """Checks whether a given image is valid.

    A valid image is expected to be with dtype `uint8`. Also, it should have
    shape like:

    (1) (height, width, 1)  # gray-scale image.
    (2) (height, width, 3)  # colorful image.
    (3) (height, width, 4)  # colorful image with transparency (RGBA)
    """
    assert isinstance(image, np.ndarray)
    assert image.dtype == np.uint8
    assert image.ndim == 3 and image.shape[2] in [1, 3, 4]


def get_blank_image(height, width, channels=3, use_black=True):
    """Gets a blank image, either white of black.

    NOTE: This function will always return an image with `RGB` channel order for
    color image and pixel range [0, 255].

    Args:
        height: Height of the returned image.
        width: Width of the returned image.
        channels: Number of channels. (default: 3)
        use_black: Whether to return a black image. (default: True)
    """
    shape = (height, width, channels)
    if use_black:
        return np.zeros(shape, dtype=np.uint8)
    return np.ones(shape, dtype=np.uint8) * 255


def load_image(path):
    """Loads an image from disk.

    NOTE: This function will always return an image with `RGB` channel order for
    color image and pixel range [0, 255].

    Args:
        path: Path to load the image from.

    Returns:
        An image with dtype `np.ndarray`, or `None` if `path` does not exist.
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        return None

    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    _check_2d_image(image)
    if image.shape[2] == 3:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    return image


def save_image(path, image):
    """Saves an image to disk.

    NOTE: The input image (if colorful) is assumed to be with `RGB` channel
    order and pixel range [0, 255].

    Args:
        path: Path to save the image to.
        image: Image to save.
    """
    if image is None:
        return

    _check_2d_image(image)
    if image.shape[2] == 1:
        cv2.imwrite(path, image)
    elif image.shape[2] == 3:
        cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    elif image.shape[2] == 4:
        cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA))


def resize_image(image, *args, **kwargs):
    """Resizes image.

    This is a wrap of `cv2.resize()`.

    NOTE: The channel order of the input image will not be changed.

    Args:
        image: Image to resize.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        An image with dtype `np.ndarray`, or `None` if `image` is empty.
    """
    if image is None:
        return None

    _check_2d_image(image)
    if image.shape[2] == 1:  # Re-expand the squeezed dim of gray-scale image.
        return cv2.resize(image, *args, **kwargs)[:, :, np.newaxis]
    return cv2.resize(image, *args, **kwargs)


def add_text_to_image(image,
                      text='',
                      position=None,
                      font=cv2.FONT_HERSHEY_TRIPLEX,
                      font_size=1.0,
                      line_type=cv2.LINE_8,
                      line_width=1,
                      color=(255, 255, 255)):
    """Overlays text on given image.

    NOTE: The input image is assumed to be with `RGB` channel order.

    Args:
        image: The image to overlay text on.
        text: Text content to overlay on the image. (default: empty)
        position: Target position (bottom-left corner) to add text. If not set,
            center of the image will be used by default. (default: None)
        font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
        font_size: Font size of the text added. (default: 1.0)
        line_type: Line type used to depict the text. (default: cv2.LINE_8)
        line_width: Line width used to depict the text. (default: 1)
        color: Color of the text added in `RGB` channel order. (default:
            (255, 255, 255))

    Returns:
        An image with target text overlaid on.
    """
    if image is None or not text:
        return image

    _check_2d_image(image)
    cv2.putText(img=image,
                text=text,
                org=position,
                fontFace=font,
                fontScale=font_size,
                color=color,
                thickness=line_width,
                lineType=line_type,
                bottomLeftOrigin=False)
    return image


def preprocess_image(image, min_val=-1.0, max_val=1.0):
    """Pre-processes image by adjusting the pixel range and to dtype `float32`.

    This function is particularly used to convert an image or a batch of images
    to `NCHW` format, which matches the data type commonly used in deep models.

    NOTE: The input image is assumed to be with pixel range [0, 255] and with
    format `HWC` or `NHWC`. The returned image will be always be with format
    `NCHW`.

    Args:
        image: The input image for pre-processing.
        min_val: Minimum value of the output image.
        max_val: Maximum value of the output image.

    Returns:
        The pre-processed image.
    """
    assert isinstance(image, np.ndarray)

    image = image.astype(np.float64)
    image = image / 255.0 * (max_val - min_val) + min_val

    if image.ndim == 3:
        image = image[np.newaxis]
    assert image.ndim == 4 and image.shape[3] in [1, 3, 4]
    return image.transpose(0, 3, 1, 2)


def postprocess_image(image, min_val=-1.0, max_val=1.0):
    """Post-processes image to pixel range [0, 255] with dtype `uint8`.

    This function is particularly used to handle the results produced by deep
    models.

    NOTE: The input image is assumed to be with format `NCHW`, and the returned
    image will always be with format `NHWC`.

    Args:
        image: The input image for post-processing.
        min_val: Expected minimum value of the input image.
        max_val: Expected maximum value of the input image.

    Returns:
        The post-processed image.
    """
    assert isinstance(image, np.ndarray)

    image = image.astype(np.float64)
    image = (image - min_val) / (max_val - min_val) * 255
    image = np.clip(image + 0.5, 0, 255).astype(np.uint8)

    assert image.ndim == 4 and image.shape[1] in [1, 3, 4]
    return image.transpose(0, 2, 3, 1)


def parse_image_size(obj):
    """Parses an object to a pair of image size, i.e., (height, width).

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple, indicating image height and width respectively.

    Raises:
        If the input is invalid, i.e., neither a list or tuple, nor a string.
    """
    if obj is None or obj == '':
        height = 0
        width = 0
    elif isinstance(obj, int):
        height = obj
        width = obj
    elif isinstance(obj, (list, tuple, str, np.ndarray)):
        if isinstance(obj, str):
            splits = obj.replace(' ', '').split(',')
            numbers = tuple(map(int, splits))
        else:
            numbers = tuple(obj)
        if len(numbers) == 0:
            height = 0
            width = 0
        elif len(numbers) == 1:
            height = int(numbers[0])
            width = int(numbers[0])
        elif len(numbers) == 2:
            height = int(numbers[0])
            width = int(numbers[1])
        else:
            raise ValueError('At most two elements for image size.')
    else:
        raise ValueError(f'Invalid type of input: `{type(obj)}`!')

    return (max(0, height), max(0, width))


def get_grid_shape(size, height=0, width=0, is_portrait=False):
    """Gets the shape of a grid based on the size.

    This function makes greatest effort on making the output grid square if
    neither `height` nor `width` is set. If `is_portrait` is set as `False`, the
    height will always be equal to or smaller than the width. For example, if
    input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`,
    output shape will be (3, 5). Otherwise, the height will always be equal to
    or larger than the width.

    Args:
        size: Size (height * width) of the target grid.
        height: Expected height. If `size % height != 0`, this field will be
            ignored. (default: 0)
        width: Expected width. If `size % width != 0`, this field will be
            ignored. (default: 0)
        is_portrait: Whether to return a portrait size of a landscape size.
            (default: False)

    Returns:
        A two-element tuple, representing height and width respectively.
    """
    assert isinstance(size, int)
    assert isinstance(height, int)
    assert isinstance(width, int)
    if size <= 0:
        return (0, 0)

    if height > 0 and width > 0 and height * width != size:
        height = 0
        width = 0

    if height > 0 and width > 0 and height * width == size:
        return (height, width)
    if height > 0 and size % height == 0:
        return (height, size // height)
    if width > 0 and size % width == 0:
        return (size // width, width)

    height = int(np.sqrt(size))
    while height > 0:
        if size % height == 0:
            width = size // height
            break
        height = height - 1

    return (width, height) if is_portrait else (height, width)


def list_images_from_dir(directory):
    """Lists all images from the given directory.

    NOTE: Do NOT support finding images recursively.

    Args:
        directory: The directory to find images from.

    Returns:
        A list of sorted filenames, with the directory as prefix.
    """
    image_list = []
    for filename in os.listdir(directory):
        if check_file_ext(filename, *IMAGE_EXTENSIONS):
            image_list.append(os.path.join(directory, filename))
    return sorted(image_list)