# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple, Union

import torch
import torch.fft as fft

from ..utils.torch_utils import randn_tensor


class FreeInitMixin:
    r"""Mixin class for FreeInit.

    Stores the FreeInit configuration on the host object as `_free_init_*` attributes and provides the
    frequency-domain noise re-initialization helpers. `_apply_free_init` reads `self.scheduler`, so the class this
    mixin is combined with must provide that attribute.
    """

    def enable_free_init(
        self,
        num_iters: int = 3,
        use_fast_sampling: bool = False,
        method: str = "butterworth",
        order: int = 4,
        spatial_stop_frequency: float = 0.25,
        temporal_stop_frequency: float = 0.25,
    ):
        """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.

        This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).

        Args:
            num_iters (`int`, *optional*, defaults to `3`):
                Number of FreeInit noise re-initialization iterations.
            use_fast_sampling (`bool`, *optional*, defaults to `False`):
                Whether or not to speed up the sampling procedure at the cost of probably lower quality results.
                Enables the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
            method (`str`, *optional*, defaults to `butterworth`):
                Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
                FreeInit low pass filter.
            order (`int`, *optional*, defaults to `4`):
                Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
                whereas lower values lead to `gaussian` method behaviour.
            spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
                Normalized stop frequency for spatial dimensions. Must be between 0 and 1. Referred to as `d_s` in
                the original implementation.
            temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
                Normalized stop frequency for temporal dimensions. Must be between 0 and 1. Referred to as `d_t` in
                the original implementation.
        """
        # The values are only stored here; they are consumed later by `_apply_free_init`.
        self._free_init_num_iters = num_iters
        self._free_init_use_fast_sampling = use_fast_sampling
        self._free_init_method = method
        self._free_init_order = order
        self._free_init_spatial_stop_frequency = spatial_stop_frequency
        self._free_init_temporal_stop_frequency = temporal_stop_frequency

    def disable_free_init(self):
        """Disables the FreeInit mechanism if enabled."""
        # Only the iteration count is reset; `free_init_enabled` keys off this attribute alone.
        self._free_init_num_iters = None

    @property
    def free_init_enabled(self):
        """Whether FreeInit is active, i.e. `enable_free_init` was called and not undone by `disable_free_init`."""
        return getattr(self, "_free_init_num_iters", None) is not None

    def _get_free_init_freq_filter(
        self,
        shape: Tuple[int, ...],
        device: Union[str, torch.device],
        filter_type: str,
        order: float,
        spatial_stop_frequency: float,
        temporal_stop_frequency: float,
    ) -> torch.Tensor:
        r"""Returns the FreeInit filter based on filter type and other input conditions.

        Args:
            shape (`Tuple[int, ...]`):
                Shape of the filter. The last three entries are read as `(time, height, width)`; the filter values
                are identical along every leading dimension.
            device (`str` or `torch.device`):
                Device the returned filter is moved to.
            filter_type (`str`):
                One of `butterworth`, `gaussian` or `ideal`.
            order (`float`):
                Exponent used by the `butterworth` filter; ignored by the other filter types.
            spatial_stop_frequency (`float`):
                Normalized spatial stop frequency.
            temporal_stop_frequency (`float`):
                Normalized temporal stop frequency.

        Returns:
            `torch.Tensor`: Filter of shape `shape` on `device`. All zeros if either stop frequency is `0`.

        Raises:
            `NotImplementedError`: If `filter_type` is not one of the supported values (only checked when both
            stop frequencies are non-zero).
        """
        time, height, width = shape[-3], shape[-2], shape[-1]
        mask = torch.zeros(shape)

        if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
            # Move the all-zero filter to `device` as well, so it can be combined with tensors on that device
            # exactly like the filter returned by the regular path below.
            return mask.to(device)

        if filter_type not in ("butterworth", "gaussian", "ideal"):
            raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")

        def squared_axis(size: int, scale: float = 1.0) -> torch.Tensor:
            # Index i of an axis of length `size` maps to `2 * i / size - 1`; computed in float64 to match the
            # precision of plain Python float arithmetic.
            coords = 2 * torch.arange(size, dtype=torch.float64) / size - 1
            return (scale * coords) ** 2

        # Squared distance of every (t, h, w) position, built by broadcasting the three axes against each other.
        # The temporal axis is rescaled by the ratio of the two stop frequencies.
        d_square = (
            squared_axis(time, spatial_stop_frequency / temporal_stop_frequency)[:, None, None]
            + squared_axis(height)[None, :, None]
            + squared_axis(width)[None, None, :]
        )

        if filter_type == "butterworth":
            values = 1 / (1 + (d_square / spatial_stop_frequency**2) ** order)
        elif filter_type == "gaussian":
            values = torch.exp(-1 / (2 * spatial_stop_frequency**2) * d_square)
        else:
            # NOTE(review): the threshold is `2 * spatial_stop_frequency`, whereas the other filters use the
            # squared stop frequency. Kept unchanged to preserve existing outputs - confirm this is intended.
            values = (d_square <= spatial_stop_frequency * 2).to(torch.float64)

        # Broadcast the (time, height, width) values over all leading dimensions of the filter.
        mask[...] = values.to(mask.dtype)

        return mask.to(device)

    def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
        r"""Noise reinitialization.

        Combines the part of the spectrum of `x` selected by `low_pass_filter` with the part of the spectrum of
        `noise` selected by `1 - low_pass_filter`. The FFT is taken over the last three dimensions.

        Args:
            x (`torch.Tensor`): Tensor whose filtered spectrum is weighted by `low_pass_filter`.
            noise (`torch.Tensor`): Tensor whose spectrum is weighted by `1 - low_pass_filter`.
            low_pass_filter (`torch.Tensor`): Filter multiplied with the shifted spectrum of `x`.

        Returns:
            `torch.Tensor`: Real part of the inverse FFT of the mixed spectrum.
        """
        fft_dims = (-3, -2, -1)

        # Forward FFT of both inputs, shifted so that it lines up with the layout of the filter.
        x_spectrum = fft.fftshift(fft.fftn(x, dim=fft_dims), dim=fft_dims)
        noise_spectrum = fft.fftshift(fft.fftn(noise, dim=fft_dims), dim=fft_dims)

        # Mix in the frequency domain: filtered `x` plus the complementary part of `noise`.
        high_pass_filter = 1 - low_pass_filter
        mixed_spectrum = x_spectrum * low_pass_filter + noise_spectrum * high_pass_filter

        # Undo the shift and transform back; only the real component is kept.
        mixed_spectrum = fft.ifftshift(mixed_spectrum, dim=fft_dims)
        return fft.ifftn(mixed_spectrum, dim=fft_dims).real

    def _apply_free_init(
        self,
        latents: torch.Tensor,
        free_init_iteration: int,
        num_inference_steps: int,
        device: torch.device,
        dtype: torch.dtype,
        generator: torch.Generator,
    ):
        r"""Prepares the latents and timesteps for one FreeInit iteration.

        On iteration `0` the incoming latents are stored as the initial noise and returned unchanged. On later
        iterations the latents are re-noised through `self.scheduler.add_noise` with that stored noise at timestep
        `num_train_timesteps - 1`, and then frequency-mixed with freshly sampled noise.

        Args:
            latents (`torch.Tensor`): Current latents; the first dimension is used as the batch size.
            free_init_iteration (`int`): Zero-based index of the current FreeInit iteration.
            num_inference_steps (`int`): Number of inference steps requested for a full sampling run.
            device (`torch.device`): Device for the filter, the timesteps and the sampled noise.
            dtype (`torch.dtype`): Dtype the returned latents are cast to.
            generator (`torch.Generator`): Generator forwarded to `randn_tensor`.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: The latents to start this iteration from and
            `self.scheduler.timesteps`.
        """
        if free_init_iteration == 0:
            # NOTE(review): with `use_fast_sampling` enabled the first iteration still returns the scheduler's
            # current timesteps unchanged - confirm whether the reduced schedule should apply here as well.
            self._free_init_initial_noise = latents.detach().clone()
            return latents, self.scheduler.timesteps

        latent_shape = latents.shape

        # A single filter with batch size 1 is built and broadcast over the batch dimension.
        free_init_filter_shape = (1, *latent_shape[1:])
        free_init_freq_filter = self._get_free_init_freq_filter(
            shape=free_init_filter_shape,
            device=device,
            filter_type=self._free_init_method,
            order=self._free_init_order,
            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
        )

        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()

        # The frequency mixing is carried out in float32 regardless of `dtype`.
        z_t = self.scheduler.add_noise(
            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
        ).to(dtype=torch.float32)

        z_rand = randn_tensor(
            shape=latent_shape,
            generator=generator,
            device=device,
            dtype=torch.float32,
        )
        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
        latents = latents.to(dtype)

        # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
        if self._free_init_use_fast_sampling:
            # Clamp to at least one step: the truncated value is 0 whenever `num_inference_steps` is small
            # relative to the number of FreeInit iterations.
            num_inference_steps = max(
                1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
            )
            self.scheduler.set_timesteps(num_inference_steps, device=device)

        return latents, self.scheduler.timesteps