# Plug&Play Feature Injection

import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from random import randrange
import PIL
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import custom_bwd, custom_fwd
import torch.nn.functional as F


from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    DDIMScheduler,
)
from diffusers.utils.torch_utils import randn_tensor

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipelineOutput,
    retrieve_timesteps,
    PipelineImageInput
)

from src.eunms import Scheduler_Type, Gradient_Averaging_Type, Epsilon_Update_Type

def _backward_ddim(x_tm1, alpha_t, alpha_tm1, eps_xt):
    """
    Take one deterministic DDIM step in the inversion direction.

    With a = alpha_t and b = alpha_tm1 the returned value is

        x_t = sqrt(a) * (x_{t-1} / sqrt(b)
                         + (sqrt(1/a - 1) - sqrt(1/b - 1)) * eps)

    which is the update from https://arxiv.org/pdf/2105.05233.pdf,
    section F, written as an explicit expression for x_t.

    Args:
        x_tm1: sample the step starts from.
        alpha_t: cumulative alpha product of the step being moved to.
        alpha_tm1: cumulative alpha product of the step being left.
        eps_xt: noise estimate used for the step.

    Returns:
        The sample at the step described by ``alpha_t``.
    """
    sqrt_alpha_t = alpha_t**0.5
    sqrt_alpha_tm1 = alpha_tm1**0.5

    # Weight of the incoming sample and of the noise estimate, respectively.
    sample_weight = 1 / sqrt_alpha_tm1
    noise_weight = (1 / alpha_t - 1) ** 0.5 - (1 / alpha_tm1 - 1) ** 0.5

    return sqrt_alpha_t * (sample_weight * x_tm1 + noise_weight * eps_xt)


class SDDDIMPipeline(StableDiffusionImg2ImgPipeline):
    # @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        strength: float = 1.0,
        num_inversion_steps: Optional[int] = 50,
        timesteps: List[int] = None,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: int = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        opt_lr: float = 0.001,
        opt_iters: int = 1,
        opt_none_inference_steps: bool = False,
        opt_loss_kl_lambda: float = 10.0,
        num_inference_steps: int = 50,
        num_aprox_steps: int = 100,
        **kwargs,
    ):
        """
        Invert ``image`` by stepping through the scheduler timesteps in reverse.

        The image is encoded to latents, then for every timestep (lowest
        first) ``inversion_step`` estimates the next latent. When the
        timestep also belongs to ``self.scheduler_inference.timesteps`` the
        estimate is refined with ``optimize_z_tp1``.

        Args:
            prompt, negative_prompt, prompt_embeds, negative_prompt_embeds,
            num_images_per_prompt, clip_skip, cross_attention_kwargs:
                forwarded to ``check_inputs`` / ``encode_prompt``.
            image: input image, run through ``image_processor.preprocess``.
            strength: forwarded to ``get_timesteps``.
            num_inversion_steps, timesteps: schedule for ``self.scheduler``.
            num_inference_steps: schedule for ``self.scheduler_inference``.
            guidance_scale: stored as ``self._guidance_scale``.
            ip_adapter_image: optional IP-Adapter image; its embeddings are
                passed to the UNet as ``added_cond_kwargs``.
            opt_lr, opt_iters, opt_loss_kl_lambda: forwarded to
                ``optimize_z_tp1``.
            num_aprox_steps: forwarded to ``inversion_step``.
            callback_on_step_end, callback_on_step_end_tensor_inputs:
                per-step callback and the local tensors handed to it.
            **kwargs: may carry the deprecated ``callback`` /
                ``callback_steps``.

            ``output_type``, ``return_dict`` and ``opt_none_inference_steps``
            are accepted but not read by this method.

        Returns:
            A tuple ``(output, all_latents)``: ``output`` is a
            ``StableDiffusionPipelineOutput`` whose ``images`` field holds
            the final latents (not a decoded image), and ``all_latents`` is
            the list of latents recorded before the loop and after each step.
        """
        # `deprecate` is not imported at module level; without this import the
        # legacy-callback warnings below would raise NameError.
        from diffusers.utils import deprecate

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        # 5. set timesteps (inversion schedule on self.scheduler, inference
        # schedule on self.scheduler_inference)
        timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
        timesteps, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        _, num_inference_steps = retrieve_timesteps(self.scheduler_inference, num_inference_steps, device, None)

        # 6. Prepare latent variables
        with torch.no_grad():
            latents = self.prepare_latents(
                image,
                latent_timestep,
                batch_size,
                num_images_per_prompt,
                prompt_embeds.dtype,
                device,
                generator,
            )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        # NOTE(review): the result is not consumed anywhere in this method.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 7.2 Optionally get Guidance Scale Embedding
        # NOTE(review): timestep_cond is computed but the UNet calls made by
        # inversion_step / optimize_z_tp1 pass timestep_cond=None.
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 8. Inversion loop
        # The loop runs over the inversion schedule, so the warm-up count must
        # be derived from num_inversion_steps (it was previously derived from
        # the unrelated inference step count, which stalled the progress bar
        # whenever the two counts differed).
        num_warmup_steps = len(timesteps) - num_inversion_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        prev_timestep = None
        self.prev_z = torch.clone(latents)
        self.prev_z4 = torch.clone(latents)
        self.z_0 = torch.clone(latents)
        # Fixed-seed noise, reused by inversion_step / optimize_z_tp1 to build
        # the forward-noised reference latents.
        g_cpu = torch.Generator().manual_seed(7865)
        self.noise = randn_tensor(self.z_0.shape, generator=g_cpu, device=self.z_0.device, dtype=self.z_0.dtype)

        all_latents = [latents.clone()]
        with self.progress_bar(total=num_inversion_steps) as progress_bar:
            for i, t in enumerate(reversed(timesteps)):

                z_tp1 = self.inversion_step(latents,
                                            t,
                                            prompt_embeds,
                                            added_cond_kwargs,
                                            prev_timestep=prev_timestep,
                                            num_aprox_steps=num_aprox_steps)

                # Refine only at timesteps the inference scheduler will visit.
                if t in self.scheduler_inference.timesteps:
                    z_tp1 = self.optimize_z_tp1(z_tp1,
                                                latents,
                                                t,
                                                prompt_embeds,
                                                added_cond_kwargs,
                                                nom_opt_iters=opt_iters,
                                                lr=opt_lr,
                                                opt_loss_kl_lambda=opt_loss_kl_lambda)

                prev_timestep = t
                latents = z_tp1

                all_latents.append(latents.clone())

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        # The "image" returned here is the final latent, not a decoded picture.
        image = latents

        # Offload all models
        self.maybe_free_model_hooks()

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None), all_latents
    
    def noise_regularization(self, e_t, noise_pred_optimal):
        for _outer in range(self.cfg.num_reg_steps):
            if self.cfg.lambda_kl>0:
                _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                # l_kld = self.kl_divergence(_var)
                l_kld = self.patchify_latents_kl_divergence(_var, noise_pred_optimal)
                l_kld.backward()
                _grad = _var.grad.detach()
                _grad = torch.clip(_grad, -100, 100)
                e_t = e_t - self.cfg.lambda_kl*_grad
            if self.cfg.lambda_ac>0:
                for _inner in range(self.cfg.num_ac_rolls):
                    _var = torch.autograd.Variable(e_t.detach().clone(), requires_grad=True)
                    l_ac = self.auto_corr_loss(_var)
                    l_ac.backward()
                    _grad = _var.grad.detach()/self.cfg.num_ac_rolls
                    e_t = e_t - self.cfg.lambda_ac*_grad
            e_t = e_t.detach()

        return e_t

    def auto_corr_loss(self, x, random_shift=True):
        B,C,H,W = x.shape
        assert B==1
        x = x.squeeze(0)
        # x must be shape [C,H,W] now
        reg_loss = 0.0
        for ch_idx in range(x.shape[0]):
            noise = x[ch_idx][None, None,:,:]
            while True:
                if random_shift: roll_amount = randrange(noise.shape[2]//2)
                else: roll_amount = 1
                reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=2)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=roll_amount, dims=3)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        return reg_loss
    
    def kl_divergence(self, x):
        _mu = x.mean()
        _var = x.var()
        return _var + _mu**2 - 1 - torch.log(_var+1e-7)

    # @torch.no_grad()
    def inversion_step(
        self,
        z_t: torch.Tensor,
        t: torch.Tensor,
        prompt_embeds,
        added_cond_kwargs,
        prev_timestep: Optional[torch.Tensor] = None,
        num_aprox_steps: int = 100
    ) -> torch.Tensor:
        """
        Estimate the next (inverted) latent from ``z_t`` by fixed-point iteration.

        Starting from a copy of ``z_t``, the UNet is evaluated on the current
        estimate ``num_aprox_steps + 1`` times; each prediction is optionally
        averaged, guided, regularized and then used to recompute the estimate
        from ``z_t``. Which iterations are averaged, and how the average is
        used, is controlled by ``self.cfg`` (``gradient_averaging_*``).

        Every prediction that is regularized or fed to the scheduler has
        classifier-free guidance applied first, including the reference
        prediction used by the KL regularizer and the ON_END average.
        Previously those two kept the doubled (uncond + text) batch, which
        does not match the single-batch latents they are compared with or
        applied to. Behaviour without guidance is unchanged.

        Args:
            z_t: latent at the current step.
            t: current timestep (a tensor; ``t.item()`` is read).
            prompt_embeds: text embeddings passed to the UNet.
            added_cond_kwargs: extra conditioning passed to the UNet.
            prev_timestep: timestep handled in the previous call; selects the
                previous alpha product in the non-Euler branch, with
                ``final_alpha_cumprod`` used when it is None.
            num_aprox_steps: number of refinement iterations after the first.

        Returns:
            The estimated latent for the next step.
        """
        if t.item() < 250:
            avg_range = self.cfg.gradient_averaging_first_step_range
            # When doing more than one approximation step in the first step it adds artifacts
            num_aprox_steps = min(self.cfg.max_num_aprox_steps_first_step, num_aprox_steps)
        else:
            avg_range = self.cfg.gradient_averaging_step_range

        approximated_z_tp1 = z_t.clone()
        noise_pred_avg = None

        # Reference prediction for the KL regularizer: the UNet output on the
        # forward-noised original latent at this timestep.
        noise_pred_optimal = None
        if self.cfg.num_reg_steps > 0:
            z_tp1_forward = self.scheduler.add_noise(self.z_0, self.noise, t.view((1))).detach()
            noise_pred_optimal = self._inv_apply_guidance(
                self._inv_predict_noise(z_tp1_forward, t, prompt_embeds, added_cond_kwargs)
            ).detach()

        for i in range(num_aprox_steps + 1):
            noise_pred = self._inv_predict_noise(approximated_z_tp1, t, prompt_embeds, added_cond_kwargs)

            # Running mean of the raw predictions inside the averaging window.
            if avg_range[0] <= i < avg_range[1]:
                j = i - avg_range[0]
                if noise_pred_avg is None:
                    noise_pred_avg = noise_pred.clone()
                else:
                    noise_pred_avg = j * noise_pred_avg / (j + 1) + noise_pred / (j + 1)
                if self.cfg.gradient_averaging_type == Gradient_Averaging_Type.EACH_ITER:
                    noise_pred = noise_pred_avg.clone()

            noise_pred = self._inv_apply_guidance(noise_pred)

            if i >= avg_range[0] or (self.cfg.gradient_averaging_type == Gradient_Averaging_Type.NONE and i > 0):
                noise_pred = self.noise_regularization(noise_pred, noise_pred_optimal)

            approximated_z_tp1 = self._inv_update_latents(noise_pred, t, z_t, prev_timestep)

        if self.cfg.gradient_averaging_type == Gradient_Averaging_Type.ON_END and noise_pred_avg is not None:
            # Guidance is linear, so guiding the mean equals the mean of the
            # guided predictions.
            guided_avg = self._inv_apply_guidance(noise_pred_avg)
            guided_avg = self.noise_regularization(guided_avg, noise_pred_optimal)
            approximated_z_tp1 = self._inv_update_latents(guided_avg, t, z_t, prev_timestep)

        if self.cfg.update_epsilon_type != Epsilon_Update_Type.NONE:
            noise_pred = self._inv_apply_guidance(
                self._inv_predict_noise(approximated_z_tp1, t, prompt_embeds, added_cond_kwargs)
            )
            self.scheduler.step_and_update_noise(noise_pred, t, approximated_z_tp1, z_t, return_dict=False, update_epsilon_type=self.cfg.update_epsilon_type)

        return approximated_z_tp1

    def _inv_predict_noise(self, latents, t, prompt_embeds, added_cond_kwargs):
        """Run the UNet (without gradients) on ``latents`` and return its raw prediction."""
        # Duplicate the batch so one pass yields both guidance branches.
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            return self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=None,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

    def _inv_apply_guidance(self, noise_pred):
        """Combine the two halves of a raw prediction when classifier-free guidance is on."""
        if not self.do_classifier_free_guidance:
            return noise_pred
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

    def _inv_update_latents(self, noise_pred, t, z_t, prev_timestep):
        """Recompute the next-latent estimate from ``z_t`` and a noise prediction."""
        if self.cfg.scheduler_type == Scheduler_Type.EULER:
            return self.scheduler.inv_step(noise_pred, t, z_t, return_dict=False)[0].detach()

        alpha_prod_t = self.scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = (
            self.scheduler.alphas_cumprod[prev_timestep]
            if prev_timestep is not None
            else self.scheduler.final_alpha_cumprod
        )
        return _backward_ddim(
            x_tm1=z_t,
            alpha_t=alpha_prod_t,
            alpha_tm1=alpha_prod_t_prev,
            eps_xt=noise_pred,
        )
    
    def detach_before_opt(self, z_tp1, t, prompt_embeds, added_cond_kwargs):
        z_tp1 = z_tp1.detach()
        t = t.detach()
        prompt_embeds = prompt_embeds.detach()
        return z_tp1, t, prompt_embeds, added_cond_kwargs
    
    def opt_z_tp1_single_step(
        self,
        z_tp1,
        z_t,
        t,
        prompt_embeds,
        added_cond_kwargs,
        lr=0.001,
        opt_loss_kl_lambda=10.0,
    ):
        """
        Take a single SGD step on ``z_tp1`` so that denoising it reproduces ``z_t``.

        ``z_tp1`` is turned into a parameter, one inference-scheduler step is
        applied to it, and the sum-reduced MSE and L1 distances to ``z_t``
        (each weighted 0.5) are back-propagated. The KL term is a constant
        zero here, so ``opt_loss_kl_lambda`` does not influence the update.

        Returns:
            The updated ``z_tp1``, detached.
        """
        # Only the latent is optimized; the UNet weights stay frozen.
        self.unet.requires_grad_(False)
        z_tp1, t, prompt_embeds, added_cond_kwargs = self.detach_before_opt(z_tp1, t, prompt_embeds, added_cond_kwargs)

        z_tp1 = torch.nn.Parameter(z_tp1, requires_grad=True)
        sgd = torch.optim.SGD([z_tp1], lr=lr, momentum=0.9)
        sgd.zero_grad()
        self.unet.zero_grad()

        if self.do_classifier_free_guidance:
            unet_input = torch.cat([z_tp1, z_tp1])
        else:
            unet_input = z_tp1
        unet_input = self.scheduler_inference.scale_model_input(unet_input, t)

        noise_pred = self.unet(
            unet_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=None,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        if self.do_classifier_free_guidance:
            uncond_part, text_part = noise_pred.chunk(2)
            noise_pred = uncond_part + self.guidance_scale * (text_part - uncond_part)

        # Denoise the candidate one step; the result should land on z_t.
        z_t_hat = self.scheduler_inference.step(noise_pred, t, z_tp1, return_dict=False)[0]

        target = z_t.detach()
        squared_error = torch.nn.MSELoss(reduction='sum')(z_t_hat, target)
        absolute_error = torch.nn.L1Loss(reduction='sum')(z_t_hat, target)
        direct_loss = 0.5 * squared_error + 0.5 * absolute_error
        kl_loss = torch.tensor([0]).to(z_t.device)
        loss = 1.0 * direct_loss + opt_loss_kl_lambda * kl_loss

        loss.backward()
        sgd.step()
        print(f't: {t}\t total_loss: {format(loss.item(), ".3f")}\t\t direct_loss: {format(direct_loss.item(), ".3f")}\t\t kl_loss: {format(kl_loss.item(), ".3f")}')

        return z_tp1.detach()
    
    def optimize_z_tp1(
        self,
        z_tp1,
        z_t,
        t,
        prompt_embeds,
        added_cond_kwargs,
        nom_opt_iters=1,
        lr=0.001,
        opt_loss_kl_lambda=10.0,
    ):
        """
        Refine ``z_tp1`` by SGD so that one inference step maps it back to ``z_t``.

        Each iteration applies one ``scheduler_inference`` step to the
        candidate and minimizes
        ``0.5 * MSE_sum + 0.5 * L1_sum + opt_loss_kl_lambda * KL``, where the
        KL term compares the candidate with the forward-noised original
        latent. The learning rate is halved on plateaus and the loop stops
        early once it drops below 9e-06. The candidate with the lowest loss
        seen is kept.

        The best candidate, the cached ``self.prev_z4`` and the running best
        loss are stored detached / as plain floats so no autograd history is
        retained between iterations or after the call.

        Args:
            z_tp1: initial estimate of the latent to refine.
            z_t: latent the denoised candidate should match.
            t: timestep of the candidate.
            prompt_embeds, added_cond_kwargs: conditioning for the UNet.
            nom_opt_iters: maximum number of iterations.
            lr: initial SGD learning rate.
            opt_loss_kl_lambda: weight of the KL term.

        Returns:
            The best candidate found (or the input if no iteration ran),
            detached.
        """
        l1_loss = torch.nn.L1Loss(reduction='sum')
        mse = torch.nn.MSELoss(reduction='sum')

        # Only the latent is optimized; the UNet weights stay frozen.
        self.unet.requires_grad_(False)
        z_tp1, t, prompt_embeds, added_cond_kwargs = self.detach_before_opt(z_tp1, t, prompt_embeds, added_cond_kwargs)

        z_tp1 = torch.nn.Parameter(z_tp1, requires_grad=True)
        optimizer = torch.optim.SGD([z_tp1], lr=lr, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor = 0.5, verbose=True, patience=5, cooldown=3)
        best_loss = float("inf")

        z_tp1_forward = self.scheduler.add_noise(self.z_0, self.noise, t.view((1))).detach()
        target = z_t.detach()
        z_tp1_best = None
        for i in range(nom_opt_iters):
            optimizer.zero_grad()
            self.unet.zero_grad()
            latent_model_input = torch.cat([z_tp1] * 2) if self.do_classifier_free_guidance else z_tp1
            latent_model_input = self.scheduler_inference.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=None,
                cross_attention_kwargs=self.cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            z_t_hat = self.scheduler_inference.step(noise_pred, t, z_tp1, return_dict=False)[0]

            direct_loss = 0.5 * mse(z_t_hat, target) + 0.5 * l1_loss(z_t_hat, target)
            kl_loss = self.patchify_latents_kl_divergence(z_tp1, z_tp1_forward)
            loss = 1.0 * direct_loss + opt_loss_kl_lambda * kl_loss

            loss.backward()
            loss_value = loss.item()
            best = False
            if loss_value < best_loss:
                best_loss = loss_value
                # Snapshot the candidate this loss belongs to, i.e. before the
                # optimizer step below mutates the parameter in place.
                z_tp1_best = z_tp1.detach().clone()
                best = True
            lr_scheduler.step(loss_value)
            if optimizer.param_groups[0]['lr'] < 9e-06:
                break
            optimizer.step()
            print(f't: {t}\t\t iter: {i}\t total_loss: {format(loss.item(), ".3f")}\t\t direct_loss: {format(direct_loss.item(), ".3f")}\t\t kl_loss: {format(kl_loss.item(), ".3f")}\t\t best: {best}')

        result = z_tp1_best if z_tp1_best is not None else z_tp1.detach()

        self.prev_z4 = torch.clone(result)

        return result

    def opt_inv(self,
                z_t,
                t,
                prompt_embeds,
                added_cond_kwargs,
                prev_timestep,
                nom_opt_iters=1,
                lr=0.001,
                opt_none_inference_steps=False,
                opt_loss_kl_lambda=10.0,
                num_aprox_steps=100):
        
        z_tp1 = self.inversion_step(z_t, t, prompt_embeds, added_cond_kwargs, num_aprox_steps=num_aprox_steps)

        if t in self.scheduler_inference.timesteps:
            z_tp1 = self.optimize_z_tp1(z_tp1, z_t, t, prompt_embeds, added_cond_kwargs, nom_opt_iters=nom_opt_iters, lr=lr, opt_loss_kl_lambda=opt_loss_kl_lambda)

        return z_tp1

    def latent2image(self, latents):
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # cast back to fp16 if needed
        # if needs_upcasting:
        #     self.vae.to(dtype=torch.float16)
        
        return image
    
    def patchify_latents_kl_divergence(self, x0, x1):
        # devide x0 and x1 into patches (4x64x64) -> (4x4x4)
        PATCH_SIZE = 4
        NUM_CHANNELS = 4

        def patchify_tensor(input_tensor):
            patches = input_tensor.unfold(1, PATCH_SIZE, PATCH_SIZE).unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
            patches = patches.contiguous().view(-1, NUM_CHANNELS, PATCH_SIZE, PATCH_SIZE)
            return patches
        
        x0 = patchify_tensor(x0)
        x1 = patchify_tensor(x1)

        kl = self.latents_kl_divergence(x0, x1).sum()
        # for i in range(x0.shape[0]):
        #     kl += self.latents_kl_divergence(x0[i], x1[i])
        return kl

    
    def latents_kl_divergence(self, x0, x1):
        EPSILON = 1e-6

        #{\displaystyle D_{\text{KL}}\left({\mathcal {N}}_{0}\parallel {\mathcal {N}}_{1}\right)={\frac {1}{2}}\left(\operatorname {tr} \left(\Sigma _{1}^{-1}\Sigma _{0}\right)-k+\left(\mu _{1}-\mu _{0}\right)^{\mathsf {T}}\Sigma _{1}^{-1}\left(\mu _{1}-\mu _{0}\right)+\ln \left({\frac {\det \Sigma _{1}}{\det \Sigma _{0}}}\right)\right).}
        x0 = x0.view(x0.shape[0], x0.shape[1], -1)
        x1 = x1.view(x1.shape[0], x1.shape[1], -1)
        mu0 = x0.mean(dim=-1)
        mu1 = x1.mean(dim=-1)
        var0 = x0.var(dim=-1)
        var1 = x1.var(dim=-1)
        kl = torch.log((var1 + EPSILON) / (var0 + EPSILON)) + (var0 + (mu0 - mu1)**2) / (var1 + EPSILON) - 1
        kl = torch.abs(kl).sum(dim=-1)
        # kl = torch.linalg.norm(mu0 - mu1) + torch.linalg.norm(var0 - var1)
        # kl *= 1000
        # sigma0 = torch.cov(x0)
        # sigma1 = torch.cov(x1)
        # inv_sigma1 = torch.inverse(sigma1.to(dtype=torch.float64)).to(dtype=x0.dtype)
        # k = x0.shape[1]
        # kl = 0.5 * (torch.trace(inv_sigma1 @ sigma0) - k + (mu1 - mu0).T @ inv_sigma1 @ (mu1 - mu0) + torch.log(torch.det(sigma1) / torch.det(sigma0)))
        return kl

    
class SpecifyGradient(torch.autograd.Function):
    """
    Autograd function whose backward pass emits a caller-supplied gradient.

    ``forward`` returns a placeholder one-element zero tensor. ``backward``
    ignores the incoming gradient and returns the stored ``gt_grad`` divided
    by the length of its first dimension as the gradient of
    ``input_tensor``; ``gt_grad`` itself receives no gradient.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        # Keep the gradient to inject; the forward value carries no meaning.
        ctx.save_for_backward(gt_grad)
        return torch.zeros(
            [1],
            dtype=input_tensor.dtype,
            device=input_tensor.device,
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        (injected,) = ctx.saved_tensors
        return injected / len(injected), None