import torch
import numpy as np
from safetensors import safe_open


def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        timesteps (`torch.Tensor`):
            generate embedding vectors at these timesteps
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]


# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
    return c_skip, c_out


# Compare LCMScheduler.step, Step 4
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
    if prediction_type == "epsilon":
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        pred_x_0 = (sample - sigmas * model_output) / alphas
    elif prediction_type == "v_prediction":
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(
            f"Prediction type {prediction_type} currently not supported.")

    return pred_x_0


def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas):
    if prediction_type == "epsilon":
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        sample = sample * alphas / sigmas
    else:
        raise ValueError(
            f"Prediction type {prediction_type} currently not supported.")

    return sample


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # DDIM sampling parameters
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (
            np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        # self.ddim_timesteps = (torch.linspace(100**2,1000**2,30)**0.5).round().numpy().astype(np.int64) - 1
        self.ddim_timesteps_prev = np.asarray(
            [0] + self.ddim_timesteps[:-1].tolist()
        )
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_timesteps_prev = torch.from_numpy(
            self.ddim_timesteps_prev).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(
            self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(
            device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev


@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def convert_lcm_lora(unet, path, alpha=1.0):

    if path.endswith(("ckpt",)):
        state_dict = torch.load(path, map_location="cpu")
    else:
        state_dict = {}
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)

    num_alpha = 0
    for key in state_dict.keys():
        if "alpha" in key:
            num_alpha += 1

    lora_keys = [k for k in state_dict.keys(
    ) if k.endswith("lora_down.weight")]

    updated_state_dict = {}
    for key in lora_keys:
        lora_name = key.split(".")[0]

        if lora_name.startswith("lora_unet_"):
            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")

            if "input.blocks" in diffusers_name:
                diffusers_name = diffusers_name.replace(
                    "input.blocks", "down_blocks")
            else:
                diffusers_name = diffusers_name.replace(
                    "down.blocks", "down_blocks")

            if "middle.block" in diffusers_name:
                diffusers_name = diffusers_name.replace(
                    "middle.block", "mid_block")
            else:
                diffusers_name = diffusers_name.replace(
                    "mid.block", "mid_block")
            if "output.blocks" in diffusers_name:
                diffusers_name = diffusers_name.replace(
                    "output.blocks", "up_blocks")
            else:
                diffusers_name = diffusers_name.replace(
                    "up.blocks", "up_blocks")

            diffusers_name = diffusers_name.replace(
                "transformer.blocks", "transformer_blocks")
            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
            diffusers_name = diffusers_name.replace(
                "to.out.0.lora", "to_out_lora")
            diffusers_name = diffusers_name.replace("proj.in", "proj_in")
            diffusers_name = diffusers_name.replace("proj.out", "proj_out")
            diffusers_name = diffusers_name.replace(
                "time.emb.proj", "time_emb_proj")
            diffusers_name = diffusers_name.replace(
                "conv.shortcut", "conv_shortcut")

            updated_state_dict[diffusers_name] = state_dict[key]
            up_diffusers_name = diffusers_name.replace(".down.", ".up.")
            up_key = key.replace("lora_down.weight", "lora_up.weight")
            updated_state_dict[up_diffusers_name] = state_dict[up_key]

    state_dict = updated_state_dict

    num_lora = 0
    for key in state_dict:
        if "up." in key:
            continue
        up_key = key.replace(".down.", ".up.")
        model_key = key.replace("processor.", "").replace("_lora", "").replace(
            "down.", "").replace("up.", "").replace(".lora", "")
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]

        curr_layer = unet
        while len(layer_infos) > 0:
            temp_name = layer_infos.pop(0)
            curr_layer = curr_layer.__getattr__(temp_name)

        weight_down = state_dict[key].to(
            curr_layer.weight.data.device, curr_layer.weight.data.dtype)
        weight_up = state_dict[up_key].to(
            curr_layer.weight.data.device, curr_layer.weight.data.dtype)

        if weight_up.ndim == 2:
            curr_layer.weight.data += 1/8 * alpha * \
                torch.mm(weight_up, weight_down)
        else:
            assert weight_up.ndim == 4
            curr_layer.weight.data += 1/8 * alpha * torch.mm(weight_up.flatten(
                start_dim=1), weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape)
        num_lora += 1

    return unet