import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput
from torch import Tensor

from xora.utils.torch_utils import append_dims


def simple_diffusion_resolution_dependent_timestep_shift(
    samples: Tensor,
    timesteps: Tensor,
    n: int = 32 * 32,
) -> Tensor:
    """Shift timesteps according to the token count of ``samples``.

    The token count ``m`` is the second dimension of a 3-D ``samples`` tensor,
    or the product of every dimension after the first two for 4-D / 5-D
    tensors. The log of ``(t / (1 - t)) ** 2`` is offset by
    ``2 * log(m / n)`` and mapped back through a sigmoid, so ``m == n`` leaves
    the timesteps (numerically) unchanged.

    Args:
        samples (Tensor): Tensor of shape (b, t, c), (b, c, h, w) or
            (b, c, f, h, w). Only its shape is read.
        timesteps (Tensor): Timesteps to shift.
        n (int): Reference token count at which no shift is applied.

    Returns:
        Tensor: The shifted timesteps, same shape as ``timesteps``.

    Raises:
        ValueError: If ``samples`` is not 3-, 4- or 5-dimensional.
    """
    dims = samples.shape
    rank = len(dims)
    if rank == 3:
        m = dims[1]
    elif rank in (4, 5):
        m = math.prod(dims[2:])
    else:
        raise ValueError(
            "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
        )

    squared_ratio = (timesteps / (1 - timesteps)) ** 2
    shifted_log = torch.log(squared_ratio) + 2 * math.log(m / n)
    return torch.sigmoid(0.5 * shifted_log)


def time_shift(mu: float, sigma: float, t: Tensor):
    """Apply the shift ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` to ``t``.

    Args:
        mu (float): Shift strength; enters the formula as ``exp(mu)``.
        sigma (float): Exponent applied to ``(1 / t - 1)``.
        t (Tensor): Timesteps to shift.

    Returns:
        The shifted timesteps, computed element-wise from ``t``.
    """
    exp_mu = math.exp(mu)
    stretched = (1 / t - 1) ** sigma
    return exp_mu / (exp_mu + stretched)


def get_normal_shift(
    n_tokens: int,
    min_tokens: int = 1024,
    max_tokens: int = 4096,
    min_shift: float = 0.95,
    max_shift: float = 2.05,
) -> float:
    """Linearly interpolate a shift value from a token count.

    The result lies on the straight line through ``(min_tokens, min_shift)``
    and ``(max_tokens, max_shift)``. Values of ``n_tokens`` outside
    ``[min_tokens, max_tokens]`` are extrapolated along the same line; no
    clamping is applied.

    Args:
        n_tokens (int): Token count to evaluate the line at.
        min_tokens (int): Token count that maps to ``min_shift``.
        max_tokens (int): Token count that maps to ``max_shift``.
        min_shift (float): Shift returned at ``min_tokens``.
        max_shift (float): Shift returned at ``max_tokens``.

    Returns:
        float: The interpolated shift (a number, not a callable).

    Raises:
        ZeroDivisionError: If ``max_tokens == min_tokens``.
    """
    slope = (max_shift - min_shift) / (max_tokens - min_tokens)
    intercept = min_shift - slope * min_tokens
    return slope * n_tokens + intercept


def sd3_resolution_dependent_timestep_shift(
    samples: Tensor, timesteps: Tensor
) -> Tensor:
    """
    Shifts the timestep schedule as a function of the generated resolution.

    In the SD3 paper, the authors show empirically how to shift the timesteps based on the resolution of the
    target images. For more details: https://arxiv.org/pdf/2403.03206

    In Flux they later propose a more dynamic resolution dependent timestep shift, see:
    https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66

    The token count is the second dimension for 3-D samples, or the product of all dimensions after the first
    two for 4-D / 5-D samples. It is converted to a shift with `get_normal_shift` (default arguments) and applied
    with `time_shift` using an exponent of 1.

    Args:
        samples (Tensor): A batch of samples with shape (batch_size, tokens, channels),
            (batch_size, channels, height, width) or (batch_size, channels, frame, height, width).
            Only its shape is read.
        timesteps (Tensor): The timesteps to shift.

    Returns:
        Tensor: The shifted timesteps.

    Raises:
        ValueError: If `samples` is not 3-, 4- or 5-dimensional.
    """
    dims = samples.shape
    rank = len(dims)
    if rank == 3:
        token_count = dims[1]
    elif rank in (4, 5):
        token_count = math.prod(dims[2:])
    else:
        raise ValueError(
            "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
        )

    return time_shift(get_normal_shift(token_count), 1, timesteps)


class TimestepShifter(ABC):
    """Interface for objects that remap timesteps based on the samples being processed."""

    @abstractmethod
    def shift_timesteps(self, samples: Tensor, timesteps: Tensor) -> Tensor:
        """Return the timesteps to use for ``samples``; concrete subclasses must override this."""


@dataclass
class RectifiedFlowSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor`):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop. The scheduler's `step` computes it as `sample - dt * model_output`, so it presumably
            has the shape of the `sample` passed to `step` (e.g. `(batch_size, num_channels, height, width)` for
            images) - confirm against the caller.
        pred_original_sample (`torch.FloatTensor`, *optional*, defaults to `None`):
            Slot for the predicted denoised sample (x_{0}), usable to preview progress or for guidance.
            NOTE(review): `RectifiedFlowScheduler.step` in this file builds the output with `prev_sample` only,
            so this field stays `None` there - confirm whether any other producer fills it.
    """

    # Required: the sample to feed into the next denoising iteration.
    prev_sample: torch.FloatTensor
    # Optional: left as None unless the producer explicitly provides it.
    pred_original_sample: Optional[torch.FloatTensor] = None


class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
    """
    Scheduler for rectified-flow sampling with optional resolution-dependent timestep shifting.

    Timesteps double as sigmas and run from 1 down towards 0. A denoising step is the update
    `prev_sample = sample - dt * model_output`, where `dt` is the gap between the current timestep and the next
    one (the last timestep steps down to 0). Noising interpolates linearly:
    `(1 - sigma) * original_samples + sigma * noise`.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            Length of the default schedule; also the upper bound on the number of inference steps.
        shifting (`str`, *optional*):
            `"SD3"` or `"SimpleDiffusion"` to enable the corresponding timestep shift. Any other value
            (including `None`) leaves timesteps unshifted.
        base_resolution (`int`, defaults to `32**2`):
            Reference token count passed as `n` to the `"SimpleDiffusion"` shift.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
        shifting: Optional[str] = None,
        base_resolution: int = 32**2,
    ):
        super().__init__()
        self.init_noise_sigma = 1.0
        # Stays None until `set_timesteps` is called; `step` refuses to run before that.
        self.num_inference_steps = None
        # Default schedule: num_train_timesteps evenly spaced values from 1 down to 1 / num_train_timesteps.
        self.timesteps = self.sigmas = torch.linspace(
            1, 1 / num_train_timesteps, num_train_timesteps
        )
        self.delta_timesteps = self._compute_delta_timesteps(self.timesteps)
        self.shifting = shifting
        self.base_resolution = base_resolution

    @staticmethod
    def _compute_delta_timesteps(timesteps: Tensor) -> Tensor:
        """Return the gap from each timestep to the next one; the last timestep steps down to 0."""
        following = torch.cat([timesteps[1:], torch.zeros_like(timesteps[-1:])])
        return timesteps - following

    def shift_timesteps(self, samples: Tensor, timesteps: Tensor) -> Tensor:
        """
        Apply the configured resolution-dependent shift to `timesteps`.

        Args:
            samples (`Tensor`): Samples whose shape determines the shift strength.
            timesteps (`Tensor`): Timesteps to shift.

        Returns:
            `Tensor`: The shifted timesteps, or `timesteps` itself when `self.shifting` is neither `"SD3"` nor
            `"SimpleDiffusion"`.
        """
        if self.shifting == "SD3":
            return sd3_resolution_dependent_timestep_shift(samples, timesteps)
        elif self.shifting == "SimpleDiffusion":
            return simple_diffusion_resolution_dependent_timestep_shift(
                samples, timesteps, self.base_resolution
            )
        # Unknown or unset shifting mode: pass the timesteps through untouched.
        return timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        samples: Tensor,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples. Values above
                `config.num_train_timesteps` are clamped to it.
            samples (`Tensor`):
                A batch of samples; only forwarded to `shift_timesteps` to pick the shift.
            device (`Union[str, torch.device]`, *optional*):
                The device to which the timesteps tensor will be moved.

        Raises:
            ValueError: If `num_inference_steps` is smaller than 1.
        """
        # Fail early with a clear message instead of a ZeroDivisionError / linspace error further down.
        if num_inference_steps < 1:
            raise ValueError(
                f"`num_inference_steps` must be a positive integer, got {num_inference_steps}."
            )

        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to(
            device
        )
        self.timesteps = self.shift_timesteps(samples, timesteps)
        self.delta_timesteps = self._compute_delta_timesteps(self.timesteps)
        self.num_inference_steps = num_inference_steps
        self.sigmas = self.timesteps

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Optional[int] = None
    ) -> torch.FloatTensor:
        # pylint: disable=unused-argument
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler applies no scaling.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep (ignored)

        Returns:
            `torch.FloatTensor`: the input sample, unchanged
        """
        return sample

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: torch.FloatTensor,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
        # pylint: disable=unused-argument
        """
        Predict the sample at the previous timestep with the update `sample - dt * model_output`, where `dt` is the
        step size stored for the schedule entry closest to `timestep`.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`torch.FloatTensor` or `float`):
                Either a 0-dim value (one global timestep for the whole batch) or a 2-dim tensor holding one
                timestep per token. Non-tensor values are converted to a tensor on the schedule's device.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            eta (`float`):
                Unused; accepted for signature compatibility with other schedulers.
            use_clipped_model_output (`bool`, defaults to `False`):
                Unused; accepted for signature compatibility with other schedulers.
            generator (`torch.Generator`, *optional*):
                Unused; accepted for signature compatibility with other schedulers.
            variance_noise (`torch.FloatTensor`, *optional*):
                Unused; accepted for signature compatibility with other schedulers.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`RectifiedFlowSchedulerOutput`] or `tuple`.

        Returns:
            [`RectifiedFlowSchedulerOutput`] or `tuple`:
                If return_dict is `True`, a [`RectifiedFlowSchedulerOutput`] with only `prev_sample` set is
                returned, otherwise a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` has not been called, or if `timestep` is neither 0-dim nor 2-dim.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Accept plain Python numbers (and similar) in addition to tensors.
        if not torch.is_tensor(timestep):
            timestep = torch.as_tensor(
                timestep, dtype=self.timesteps.dtype, device=self.timesteps.device
            )

        if timestep.ndim == 0:
            # Global timestep: pick the schedule entry nearest to it.
            current_index = (self.timesteps - timestep).abs().argmin()
            # gather with an unsqueezed index keeps dt as a 1-element 1-D tensor (shape (1,)).
            dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
        elif timestep.ndim == 2:
            # Timestep per token: nearest schedule entry for every element of `timestep`.
            current_index = (
                (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
            )
            dt = self.delta_timesteps[current_index]
            # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample.
            # The trailing axis lets dt broadcast over the last dimension of model_output
            # (presumably channels, i.e. (batch, tokens, channels) inputs - confirm with caller).
            dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
        else:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                "`timestep` must be 0-dimensional (global timestep) or 2-dimensional "
                f"(timestep per token), got {timestep.ndim} dimensions."
            )

        prev_sample = sample - dt * model_output

        if not return_dict:
            return (prev_sample,)

        return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Blend clean samples with noise: `(1 - sigma) * original_samples + sigma * noise`, with `sigma = timesteps`.

        Args:
            original_samples (`torch.FloatTensor`): Clean samples.
            noise (`torch.FloatTensor`): Noise to mix in.
            timesteps (`torch.FloatTensor`): Noise levels, used directly as sigmas.

        Returns:
            `torch.FloatTensor`: The noisy samples.
        """
        # append_dims presumably pads sigmas with trailing singleton dims up to original_samples.ndim
        # so it broadcasts per sample - verify against xora.utils.torch_utils.
        sigmas = append_dims(timesteps, original_samples.ndim)
        alphas = 1 - sigmas
        return alphas * original_samples + sigmas * noise