# -*- coding: utf-8 -*-
# MoNuSeg Dataset
#
# Dataset information: https://monuseg.grand-challenge.org/Home/
# Please prepare the dataset as described here: docs/readmes/monuseg.md
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

import logging
from pathlib import Path
from typing import Callable, Union, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from cell_segmentation.datasets.pannuke import PanNukeDataset
from einops import rearrange

# Module logger. getLogger() without a name yields the root logger; a NullHandler
# is attached to it.
# NOTE(review): this attaches the handler to the process-wide root logger rather
# than a module logger -- consider logging.getLogger(__name__) instead.
logger = logging.getLogger(name=None)
logger.addHandler(hdlr=logging.NullHandler())


class MoNuSegDataset(Dataset):
    """MoNuSeg nuclei segmentation dataset.

    Pairs ``<dataset_path>/images/*.png`` with ``<dataset_path>/labels/*.npy`` by file
    stem and yields an image tensor together with instance, binary and HV ground-truth maps.
    """

    # Edge length (in pixels) of the square patches produced when ``patching`` is enabled
    _PATCH_SIZE: int = 256

    def __init__(
        self,
        dataset_path: Union[Path, str],
        transforms: Union[Callable, None] = None,
        patching: bool = False,
        overlap: int = 0,
    ) -> None:
        """MoNuSeg Dataset

        Args:
            dataset_path (Union[Path, str]): Path to dataset. Must contain the subfolders
                "images" (``*.png``) and "labels" (``*.npy``); an image and its label
                share the same file stem.
            transforms (Callable, optional): Transformations to apply on images. Called as
                ``transforms(image=..., mask=...)`` and must return a dict with the keys
                "image" and "mask". Defaults to None.
            patching (bool, optional): If patches with size 256px should be used. Otherwise,
                the entire MoNuSeg images are loaded. Defaults to False.
            overlap (int, optional): Overlap in pixels between neighbouring patches, only
                used if ``patching`` is True. Must be smaller than 256.
                Recommended value other than 0 is 64. Defaults to 0.

        Raises:
            FileNotFoundError: If no ground-truth annotation file was found for an image
            ValueError: If ``patching`` is enabled and ``overlap`` is not smaller than 256
        """
        self.dataset = Path(dataset_path).resolve()
        self.transforms = transforms
        # Kept for backwards compatibility; not populated by this class.
        self.img_names = []
        self.patching = patching
        self.overlap = overlap

        if self.patching and self.overlap >= self._PATCH_SIZE:
            # Patch extraction steps by (patch size - overlap); that step must be positive.
            raise ValueError(
                f"overlap must be smaller than the patch size ({self._PATCH_SIZE}px), "
                f"got {self.overlap}"
            )

        image_path = self.dataset / "images"
        label_path = self.dataset / "labels"
        self.images = [f for f in sorted(image_path.glob("*.png")) if f.is_file()]
        masks_by_stem = {
            f.stem: f for f in sorted(label_path.glob("*.npy")) if f.is_file()
        }

        # Sanity check: every image needs a label file with the same stem. Masks are
        # matched by name instead of list position, so a missing or surplus label file
        # can neither crash with an IndexError nor shift the image/mask pairing.
        for image in self.images:
            if image.stem not in masks_by_stem:
                raise FileNotFoundError(f"Annotation for file {image.stem} is missing")
        # self.masks[i] is the annotation of self.images[i]
        self.masks = [masks_by_stem[image.stem] for image in self.images]

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str]:
        """Get one item from dataset

        Args:
            index (int): Item to get

        Returns:
            Tuple[torch.Tensor, dict, str]: Training batch
                * torch.Tensor: Image as float tensor with channels first, (C, H, W).
                  If patching is enabled, it is split into patches of shape
                  (C, n_patches_h, n_patches_w, 256, 256).
                * dict: Ground-truth values (not patched): keys are "instance_map",
                  "nuclei_binary_map" and "hv_map"
                * str: filename
        """
        img_path = self.images[index]
        # Context manager guarantees the image file handle is released.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img).astype(np.uint8)

        mask_path = self.masks[index]
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
        # only use this with trusted label files.
        mask = np.load(mask_path, allow_pickle=True).astype(np.int64)

        if self.transforms is not None:
            transformed = self.transforms(image=img, mask=mask)
            img = transformed["image"]
            mask = transformed["mask"]

        hv_map = PanNukeDataset.gen_instance_hv_map(mask)
        # Binary nuclei map: every instance id > 0 becomes foreground (1).
        np_map = mask.copy()
        np_map[np_map > 0] = 1

        # torch convert: (H, W, C) -> (C, H, W)
        img = torch.Tensor(img).type(torch.float32)
        img = img.permute(2, 0, 1)
        # Heuristic: values >= 5 mean the image is still in [0, 255] and needs scaling.
        if torch.max(img) >= 5:
            img = img / 255

        if self.patching:
            patch = self._PATCH_SIZE
            if self.overlap == 0:
                # Non-overlapping tiling; H and W must be multiples of the patch size.
                img = rearrange(img, "c (h i) (w j) -> c h w i j", i=patch, j=patch)
            else:
                # Sliding window with stride (patch - overlap) along H and W.
                step = patch - self.overlap
                img = img.unfold(1, patch, step).unfold(2, patch, step)

        masks = {
            "instance_map": torch.Tensor(mask).type(torch.int64),
            "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64),
            "hv_map": torch.Tensor(hv_map).type(torch.float32),
        }

        return img, masks, Path(img_path).name

    def __len__(self) -> int:
        """Length of Dataset

        Returns:
            int: Number of images in the dataset
        """
        return len(self.images)

    def set_transforms(self, transforms: Callable) -> None:
        """Set the transformations, can be used to exchange transformations

        Args:
            transforms (Callable): Transformations, called as
                ``transforms(image=..., mask=...)`` in ``__getitem__``
        """
        self.transforms = transforms