from typing import Dict
import numpy as np
import torch
import kornia.augmentation as K
from kornia.geometry.transform import warp_perspective

# Adapted from Kornia
class GeometricSequential:
    """Chain geometric augmentations into a single perspective warp.

    Each transform given to the constructor must expose ``p``,
    ``generate_parameters`` and ``compute_transformation`` (the only
    attributes used here). The sampled 3x3 matrices are multiplied together
    and the input is warped once with ``warp_perspective``.

    Args:
        *transforms: augmentation objects, stored as a tuple in the order given.
        align_corners: forwarded to every ``warp_perspective`` call.
    """

    def __init__(self, *transforms, align_corners=True) -> None:
        self.transforms = transforms
        self.align_corners = align_corners

    def __call__(self, x, mode="bilinear"):
        """Randomly sample the transforms and warp ``x`` with their product.

        Args:
            x: 4-D tensor, unpacked as (B, C, H, W).
            mode: interpolation mode passed to ``warp_perspective``.

        Returns:
            Tuple ``(warped, M)``. ``warped`` has spatial size (H, W). ``M`` is
            the (B, 3, 3) accumulated matrix, which can be passed to
            :meth:`apply_transform` to warp another tensor the same way.
        """
        batch, channels, height, width = x.shape
        # Identity start; expand() yields a (B, 3, 3) view without copying.
        homography = torch.eye(3, device=x.device)[None].expand(batch, 3, 3)
        for aug in self.transforms:
            # One draw per transform: it is applied to the whole batch or skipped.
            if not np.random.rand() < aug.p:
                continue
            params = aug.generate_parameters((batch, channels, height, width))
            # Right-multiplication: the accumulated matrix is I @ T1 @ T2 @ ...
            homography = homography.matmul(aug.compute_transformation(x, params))
        warped = warp_perspective(
            x,
            homography,
            dsize=(height, width),
            mode=mode,
            align_corners=self.align_corners,
        )
        return warped, homography

    def apply_transform(self, x, M, mode="bilinear"):
        """Warp ``x`` with the given matrix ``M``, keeping x's spatial size.

        Args:
            x: 4-D tensor, unpacked as (B, C, H, W).
            M: matrix passed straight to ``warp_perspective`` (for example the
                one returned by ``__call__``).
            mode: interpolation mode passed to ``warp_perspective``.
        """
        _, _, height, width = x.shape
        return warp_perspective(
            x, M, dsize=(height, width), align_corners=self.align_corners, mode=mode
        )


class RandomPerspective(K.RandomPerspective):
    """``K.RandomPerspective`` with a custom corner-offset sampler.

    Each of the four image corners is displaced independently in x and y by an
    offset drawn uniformly from ``[-s * W / 2, s * W / 2)`` and
    ``[-s * H / 2, s * H / 2)`` respectively, where ``s`` is
    ``distortion_scale``. Corners can therefore move both inward and outward.
    """

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        """Sample start/end corner points for a batch of shape (B, ..., H, W)."""
        distortion_scale = torch.as_tensor(
            self.distortion_scale, device=self._device, dtype=self._dtype
        )
        return self.random_perspective_generator(
            batch_shape[0],
            batch_shape[-2],
            batch_shape[-1],
            distortion_scale,
            self.same_on_batch,
            self.device,
            self.dtype,
        )

    def random_perspective_generator(
        self,
        batch_size: int,
        height: int,
        width: int,
        distortion_scale: torch.Tensor,
        same_on_batch: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> Dict[str, torch.Tensor]:
        r"""Get parameters for ``perspective`` for a random perspective transform.

        Args:
            batch_size (int): the tensor batch size.
            height (int) : height of the image.
            width (int): width of the image.
            distortion_scale (torch.Tensor): 0-dim tensor in [0, 1] controlling the
                maximum corner displacement (as a fraction of half the image size).
            same_on_batch (bool): if True, a single set of corner offsets is sampled
                and shared by every batch element. Default: False.
            device (torch.device): accepted but currently unused; random numbers are
                drawn on ``distortion_scale.device``. Default: cpu.
            dtype (torch.dtype): accepted but currently unused; values use
                ``distortion_scale.dtype``. Default: float32.

        Returns:
            params Dict[str, torch.Tensor]: parameters to be passed for transformation.
                - start_points (torch.Tensor): image corners (TL, TR, BR, BL) with a shape of (B, 4, 2).
                - end_points (torch.Tensor): displaced corners with a shape of (B, 4, 2).

        Raises:
            AssertionError: if ``distortion_scale`` is not a 0-dim value in [0, 1],
                or ``height``/``width`` are not positive Python ints.

        Note:
            The generated random numbers are not reproducible across different devices and dtypes.
        """
        if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
            raise AssertionError(
                f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
            )
        if not (
            type(height) is int and height > 0 and type(width) is int and width > 0
        ):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )

        start_points: torch.Tensor = torch.tensor(
            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
            device=distortion_scale.device,
            dtype=distortion_scale.dtype,
        ).expand(batch_size, -1, -1)

        # generate random offset not larger than half of the image
        fx = distortion_scale * width / 2
        fy = distortion_scale * height / 2
        # (1, 1, 2) so it broadcasts over batch and corners as (x, y) scales.
        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)

        # Uniform offsets in [-1, 1) per corner coordinate. With same_on_batch,
        # draw one (1, 4, 2) sample and broadcast it so every batch element
        # receives the identical warp the flag promises.
        if same_on_batch:
            offset = (torch.rand_like(start_points[:1]) - 0.5) * 2
            offset = offset.expand(batch_size, -1, -1)
        else:
            offset = (torch.rand_like(start_points) - 0.5) * 2
        end_points = start_points + factor * offset

        return dict(start_points=start_points, end_points=end_points)