# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import torch
from torch.nn import functional as F

from densepose.data.meshes.catalog import MeshCatalog
from densepose.structures.mesh import load_mesh_symmetry
from densepose.structures.transform_data import DensePoseTransformData


class DensePoseDataRelative:
    """
    Dense pose relative annotations that can be applied to any bounding box:
        x - normalized X coordinates [0, 255] of annotated points
        y - normalized Y coordinates [0, 255] of annotated points
        i - body part labels 0,...,24 of annotated points
        u - body part U coordinates [0, 1] of annotated points
        v - body part V coordinates [0, 1] of annotated points
        segm - 256x256 segmentation mask with values 0,...,14
    To obtain absolute x and y coordinates with respect to some bounding box,
    one first divides the data by 256, multiplies by the respective bounding
    box size and adds the bounding box offset:
        x_img = x0 + x_norm * w / 256.0
        y_img = y0 + y_norm * h / 256.0
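    For example (illustrative values), for a bounding box (x0, y0, w, h) = (10, 20, 100, 200)
    and an annotated point with x_norm = 128, y_norm = 64:
        x_img = 10 + 128 * 100 / 256.0 = 60.0
        y_img = 20 + 64 * 200 / 256.0 = 70.0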
    Segmentation masks are typically sampled to get image-based masks.
    """

    # Key for normalized X coordinates in annotation dict
    X_KEY = "dp_x"
    # Key for normalized Y coordinates in annotation dict
    Y_KEY = "dp_y"
    # Key for U part coordinates in annotation dict (used in chart-based annotations)
    U_KEY = "dp_U"
    # Key for V part coordinates in annotation dict (used in chart-based annotations)
    V_KEY = "dp_V"
    # Key for I point labels in annotation dict (used in chart-based annotations)
    I_KEY = "dp_I"
    # Key for segmentation mask in annotation dict
    S_KEY = "dp_masks"
    # Key for vertex ids (used in continuous surface embeddings annotations)
    VERTEX_IDS_KEY = "dp_vertex"
    # Key for mesh name (used in continuous surface embeddings annotations)
    MESH_NAME_KEY = "ref_model"
    # Number of body parts in segmentation masks
    N_BODY_PARTS = 14
    # Number of parts in point labels
    N_PART_LABELS = 24
    MASK_SIZE = 256

    def __init__(self, annotation, cleanup=False):
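        """
        Initialize relative DensePose annotations from an annotation dict.

        Args:
            annotation (dict): annotation that contains at least the `dp_x` and `dp_y`
                keys; chart-based (IUV) data is read from `dp_I`, `dp_U` and `dp_V`,
                continuous surface embeddings (CSE) data from `dp_vertex` and `ref_model`,
                and the coarse segmentation mask from `dp_masks`
            cleanup (bool): if True, remove the consumed DensePose keys from the
                annotation dict after construction
        """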
        self.x = torch.as_tensor(annotation[DensePoseDataRelative.X_KEY])
        self.y = torch.as_tensor(annotation[DensePoseDataRelative.Y_KEY])
        if (
            DensePoseDataRelative.I_KEY in annotation
            and DensePoseDataRelative.U_KEY in annotation
            and DensePoseDataRelative.V_KEY in annotation
        ):
            self.i = torch.as_tensor(annotation[DensePoseDataRelative.I_KEY])
            self.u = torch.as_tensor(annotation[DensePoseDataRelative.U_KEY])
            self.v = torch.as_tensor(annotation[DensePoseDataRelative.V_KEY])
        if (
            DensePoseDataRelative.VERTEX_IDS_KEY in annotation
            and DensePoseDataRelative.MESH_NAME_KEY in annotation
        ):
            self.vertex_ids = torch.as_tensor(
                annotation[DensePoseDataRelative.VERTEX_IDS_KEY], dtype=torch.long
            )
            self.mesh_id = MeshCatalog.get_mesh_id(annotation[DensePoseDataRelative.MESH_NAME_KEY])
        if DensePoseDataRelative.S_KEY in annotation:
            self.segm = DensePoseDataRelative.extract_segmentation_mask(annotation)
        self.device = torch.device("cpu")
        if cleanup:
            DensePoseDataRelative.cleanup_annotation(annotation)

    def to(self, device):
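        """
        Transfer all tensor attributes to the given device and return a new
        DensePoseDataRelative instance; if the data is already on that device,
        return self.
        """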
        if self.device == device:
            return self
        new_data = DensePoseDataRelative.__new__(DensePoseDataRelative)
        new_data.x = self.x.to(device)
        new_data.y = self.y.to(device)
        for attr in ["i", "u", "v", "vertex_ids", "segm"]:
            if hasattr(self, attr):
                setattr(new_data, attr, getattr(self, attr).to(device))
        if hasattr(self, "mesh_id"):
            new_data.mesh_id = self.mesh_id
        new_data.device = device
        return new_data

    @staticmethod
    def extract_segmentation_mask(annotation):
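        """
        Extract the coarse segmentation mask from the `dp_masks` annotation entry.
        The entry can be a pre-decoded mask tensor (returned as is), a single
        encoded mask accepted by `pycocotools.mask.decode` (decoded into a binary
        mask), or a list of per-part encoded masks (decoded into a mask with
        values 1,...,N_BODY_PARTS).

        Returns:
            segm (FloatTensor): segmentation mask of size MASK_SIZE x MASK_SIZE
        """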
        import pycocotools.mask as mask_utils

        # TODO: an annotation instance is accepted if it contains either
        # DensePose segmentation or instance segmentation. However, here we
        # only rely on the DensePose segmentation
        poly_specs = annotation[DensePoseDataRelative.S_KEY]
        if isinstance(poly_specs, torch.Tensor):
            # data is already given as mask tensors, no need to decode
            return poly_specs
        segm = torch.zeros((DensePoseDataRelative.MASK_SIZE,) * 2, dtype=torch.float32)
        if isinstance(poly_specs, dict):
            if poly_specs:
                mask = mask_utils.decode(poly_specs)
                segm[mask > 0] = 1
        else:
            for i, poly_i in enumerate(poly_specs):
                if poly_i:
                    mask_i = mask_utils.decode(poly_i)
                    segm[mask_i > 0] = i + 1
        return segm

    @staticmethod
    def validate_annotation(annotation):
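        """
        Check that the annotation dict contains the data required to construct
        a DensePoseDataRelative instance: `dp_x` and `dp_y`, plus either the
        chart-based keys (`dp_I`, `dp_U`, `dp_V`) or the CSE keys
        (`dp_vertex`, `ref_model`).

        Returns:
            (bool, str or None): validity flag and an error message if the
                annotation is invalid, None otherwise
        """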
        for key in [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
        ]:
            if key not in annotation:
                return False, "no {key} data in the annotation".format(key=key)
        valid_for_iuv_setting = all(
            key in annotation
            for key in [
                DensePoseDataRelative.I_KEY,
                DensePoseDataRelative.U_KEY,
                DensePoseDataRelative.V_KEY,
            ]
        )
        valid_for_cse_setting = all(
            key in annotation
            for key in [
                DensePoseDataRelative.VERTEX_IDS_KEY,
                DensePoseDataRelative.MESH_NAME_KEY,
            ]
        )
        if not valid_for_iuv_setting and not valid_for_cse_setting:
            return (
                False,
                "expected either {} (IUV setting) or {} (CSE setting) annotations".format(
                    ", ".join(
                        [
                            DensePoseDataRelative.I_KEY,
                            DensePoseDataRelative.U_KEY,
                            DensePoseDataRelative.V_KEY,
                        ]
                    ),
                    ", ".join(
                        [
                            DensePoseDataRelative.VERTEX_IDS_KEY,
                            DensePoseDataRelative.MESH_NAME_KEY,
                        ]
                    ),
                ),
            )
        return True, None

    @staticmethod
    def cleanup_annotation(annotation):
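        """
        Remove all DensePose-specific keys from the annotation dict, if present.
        """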
        for key in [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
            DensePoseDataRelative.I_KEY,
            DensePoseDataRelative.U_KEY,
            DensePoseDataRelative.V_KEY,
            DensePoseDataRelative.S_KEY,
            DensePoseDataRelative.VERTEX_IDS_KEY,
            DensePoseDataRelative.MESH_NAME_KEY,
        ]:
            if key in annotation:
                del annotation[key]

    def apply_transform(self, transforms, densepose_transform_data):
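        """
        Apply augmentation transforms (horizontal flips and rotations) to the
        annotated points and, if present, to the segmentation mask.
        """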
        self._transform_pts(transforms, densepose_transform_data)
        if hasattr(self, "segm"):
            self._transform_segm(transforms, densepose_transform_data)

    def _transform_pts(self, transforms, dp_transform_data):
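        """
        Apply horizontal flips and rotations to the annotated points; on flips,
        also remap IUV or vertex semantics accordingly.
        """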
        import detectron2.data.transforms as T

        # NOTE: This assumes that HFlipTransform is the only transform that performs a flip
        do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
        if do_hflip:
            self.x = self.MASK_SIZE - self.x
            if hasattr(self, "i"):
                self._flip_iuv_semantics(dp_transform_data)
            if hasattr(self, "vertex_ids"):
                self._flip_vertices()

        for t in transforms.transforms:
            if isinstance(t, T.RotationTransform):
                xy_scale = np.array((t.w, t.h)) / DensePoseDataRelative.MASK_SIZE
                xy = t.apply_coords(np.stack((self.x, self.y), axis=1) * xy_scale)
                self.x, self.y = torch.tensor(xy / xy_scale, dtype=self.x.dtype).T

    def _flip_iuv_semantics(self, dp_transform_data: DensePoseTransformData) -> None:
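        """
        Remap part labels and UV coordinates of the annotated points under a
        horizontal flip, using the point label and UV symmetries provided by
        `dp_transform_data`.
        """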
        i_old = self.i.clone()
        uv_symmetries = dp_transform_data.uv_symmetries
        pt_label_symmetries = dp_transform_data.point_label_symmetries
        for i in range(self.N_PART_LABELS):
            if i + 1 in i_old:
                annot_indices_i = i_old == i + 1
                if pt_label_symmetries[i + 1] != i + 1:
                    self.i[annot_indices_i] = pt_label_symmetries[i + 1]
                u_loc = (self.u[annot_indices_i] * 255).long()
                v_loc = (self.v[annot_indices_i] * 255).long()
                self.u[annot_indices_i] = uv_symmetries["U_transforms"][i][v_loc, u_loc].to(
                    device=self.u.device
                )
                self.v[annot_indices_i] = uv_symmetries["V_transforms"][i][v_loc, u_loc].to(
                    device=self.v.device
                )

    def _flip_vertices(self):
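        """
        Remap vertex ids under a horizontal flip using the mesh symmetry data;
        assumes that the mesh registered in MeshCatalog provides symmetry information.
        """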
        mesh_info = MeshCatalog[MeshCatalog.get_mesh_name(self.mesh_id)]
        mesh_symmetry = (
            load_mesh_symmetry(mesh_info.symmetry) if mesh_info.symmetry is not None else None
        )
        self.vertex_ids = mesh_symmetry["vertex_transforms"][self.vertex_ids]

    def _transform_segm(self, transforms, dp_transform_data):
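        """
        Apply horizontal flips and rotations to the coarse segmentation mask.
        """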
        import detectron2.data.transforms as T

        # NOTE: This assumes that HFlipTransform is the only transform that performs a flip
        do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
        if do_hflip:
            self.segm = torch.flip(self.segm, [1])
            self._flip_segm_semantics(dp_transform_data)

        for t in transforms.transforms:
            if isinstance(t, T.RotationTransform):
                self._transform_segm_rotation(t)

    def _flip_segm_semantics(self, dp_transform_data):
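        """
        Remap body part labels of the segmentation mask under a horizontal flip,
        using the mask label symmetries provided by `dp_transform_data`.
        """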
        old_segm = self.segm.clone()
        mask_label_symmetries = dp_transform_data.mask_label_symmetries
        for i in range(self.N_BODY_PARTS):
            if mask_label_symmetries[i + 1] != i + 1:
                self.segm[old_segm == i + 1] = mask_label_symmetries[i + 1]

    def _transform_segm_rotation(self, rotation):
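        """
        Rotate the segmentation mask: resize it to the image size expected by the
        rotation transform, apply the rotation, then resize the result back to
        MASK_SIZE x MASK_SIZE.
        """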
        self.segm = F.interpolate(self.segm[None, None, :], (rotation.h, rotation.w)).numpy()
        self.segm = torch.tensor(rotation.apply_segmentation(self.segm[0, 0]))[None, None, :]
        self.segm = F.interpolate(self.segm, [DensePoseDataRelative.MASK_SIZE] * 2)[0, 0]