import contextlib
import random
import numpy as np
import os
from glob import glob
from PIL import Image, ImageSequence

import torch
from torchvision.io import read_video, write_video
import torchvision.transforms as T

from diffusers import DDIMScheduler, StableDiffusionControlNetPipeline, StableDiffusionPipeline, StableDiffusionDepth2ImgPipeline, ControlNetModel
from .controlnet_utils import CONTROLNET_DICT, control_preprocess
from einops import rearrange

FRAME_EXT = [".jpg", ".png"]


def init_model(device="cuda", sd_version="1.5", model_key=None, control_type="none", weight_dtype="fp16"):
    """Load a Stable Diffusion pipeline together with a DDIM scheduler.

    Args:
        device: device the pipeline is moved to.
        sd_version: one of "1.5", "2.0", "2.1" or "depth"; only used to pick a
            default checkpoint when ``model_key`` is None.
        model_key: Hugging Face id or local path; overrides ``sd_version``.
        control_type: key into ``CONTROLNET_DICT``; "none" and "pnp" load no
            ControlNet.
        weight_dtype: "fp16" -> float16, "bf16" -> bfloat16, any other string
            -> float32; a ``torch.dtype`` is used as-is.

    Returns:
        ``(pipe, scheduler, model_key)`` with ``pipe`` already on ``device``.

    Raises:
        ValueError: if ``model_key`` is None and ``sd_version`` is unknown.
    """
    default_model_keys = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
        'depth': "stabilityai/stable-diffusion-2-depth",
    }

    use_depth = False
    if model_key is None:
        if sd_version not in default_model_keys:
            raise ValueError(
                f'Stable-diffusion version {sd_version} not supported.')
        model_key = default_model_keys[sd_version]
        use_depth = sd_version == 'depth'
        print(f'[INFO] loading stable diffusion from: {model_key}')
    else:
        print(f'[INFO] loading custom model from: {model_key}')

    scheduler = DDIMScheduler.from_pretrained(
        model_key, subfolder="scheduler")

    # Previously every value except "fp16" (including "bf16") silently fell
    # back to float32; honour bf16 and explicit torch dtypes as well.
    if not isinstance(weight_dtype, torch.dtype):
        weight_dtype = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }.get(weight_dtype, torch.float32)

    if control_type not in ["none", "pnp"]:
        controlnet_key = CONTROLNET_DICT[control_type]
        print(f'[INFO] loading controlnet from: {controlnet_key}')
        controlnet = ControlNetModel.from_pretrained(
            controlnet_key, torch_dtype=weight_dtype)
        print(f'[INFO] loaded controlnet!')
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_key, controlnet=controlnet, torch_dtype=weight_dtype
        )
    elif use_depth:
        pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
            model_key, torch_dtype=weight_dtype
        )
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=weight_dtype,
        )

    return pipe.to(device), scheduler, model_key


def seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNGs.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def load_image(image_path):
    """Read an image file as an RGB float tensor with a leading batch dim.

    Args:
        image_path: path of the image file.

    Returns:
        ``T.ToTensor()`` of the RGB image with ``unsqueeze(0)`` applied,
        i.e. shape (1, 3, H, W).
    """
    # Context manager so the file handle is released deterministically
    # (Pillow keeps it open for multi-frame files such as APNG).
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return T.ToTensor()(rgb).unsqueeze(0)


def process_frames(frames, h, w):
    """Resize frames to cover (h, w) and centre-crop them to exactly that size.

    ``h`` and ``w`` are first rounded down to multiples of 64. Each frame is
    then resized, keeping its aspect ratio, so that it is at least ``h`` x
    ``w``, and finally centre-cropped to ``(h, w)``.

    Args:
        frames: tensor with at least 3 dims. A 3-D tensor is treated as a
            single frame (presumably (C, H, W) - TODO confirm); otherwise each
            entry along dim 0 is a frame.
        h, w: requested output height and width.

    Returns:
        ``torch.stack`` of the cropped frames, one entry per input frame.

    Raises:
        ValueError: if ``frames`` has fewer than 3 dims, or if ``h`` or ``w``
            becomes 0 after rounding down to a multiple of 64.
    """
    if len(frames.shape) < 3:
        raise ValueError(
            f"expected a tensor with at least 3 dims, got shape {tuple(frames.shape)}")

    fh, fw = frames.shape[-2:]
    h = int(np.floor(h / 64.0)) * 64
    w = int(np.floor(w / 64.0)) * 64
    if h <= 0 or w <= 0:
        raise ValueError(
            f"target size {(h, w)} is empty after rounding down to multiples of 64")

    # Choose the aspect-preserving resize that covers (h, w). Exact integer
    # floor division (instead of int(fw / fh * h)) avoids float round-off such
    # as 1023.999 -> 1023, which could leave the resized frame smaller than
    # the crop so that CenterCrop pads instead of cropping.
    nw = fw * h // fh
    size = (h, nw) if nw >= w else (fh * w // fw, w)

    if len(frames.shape) == 3:
        frames = [frames]

    print(
        f"[INFO] frame size {(fh, fw)} resize to {size} and centercrop to {(h, w)}")

    resize = T.Resize(size)
    crop = T.CenterCrop([h, w])
    return torch.stack([crop(resize(frame)) for frame in frames])


def glob_frame_paths(video_path, exts=None):
    """Return the sorted paths of the frame images directly inside a directory.

    Args:
        video_path: directory containing the frame images.
        exts: file extensions (with leading dot) to collect; defaults to the
            module-level ``FRAME_EXT``.

    Returns:
        Lexicographically sorted list of matching paths, so frame files need
        zero-padded names (e.g. ``0001.png``) to sort in temporal order.
    """
    from glob import escape

    if exts is None:
        exts = FRAME_EXT
    # Escape the directory so names containing glob metacharacters
    # ("[", "]", "*", "?") are matched literally.
    root = escape(video_path)
    return sorted(
        path for ext in exts for path in glob(os.path.join(root, f"*{ext}")))


def load_video(video_path, h, w, frame_ids=None, device="cuda"):
    """Load a video as a tensor of frames resized and cropped to (h, w).

    The input kind is chosen by the path's extension (case-insensitive):
    ``.mp4`` and ``.gif`` files are decoded directly; any other path is treated
    as a directory of frame images (see ``glob_frame_paths``).

    Args:
        video_path: path to an .mp4/.gif file or to a directory of frames.
        h, w: target size, forwarded to ``process_frames``.
        frame_ids: optional index/indices selecting a subset of frames.
        device: device the returned tensor is moved to.

    Returns:
        Output of ``process_frames`` moved to ``device``. Loaded pixel values
        are scaled to [0, 1] before resizing.

    Raises:
        FileNotFoundError: if a frame directory contains no frame images.
    """
    # Match on the suffix rather than a substring so that e.g. a directory
    # named "clip.mp4_frames" is not mistaken for an mp4 file.
    ext = os.path.splitext(video_path)[1].lower()
    if ext == ".mp4":
        frames, _, _ = read_video(
            video_path, output_format="TCHW", pts_unit="sec")
        frames = frames / 255
    elif ext == ".gif":
        # Context manager so the GIF file handle is released after decoding.
        with Image.open(video_path) as gif:
            frames = torch.stack([
                T.ToTensor()(frame.convert("RGB"))
                for frame in ImageSequence.Iterator(gif)
            ])
    else:
        frame_paths = glob_frame_paths(video_path)
        if not frame_paths:
            raise FileNotFoundError(f"No frame images found in {video_path}")
        frames = torch.cat([load_image(frame_path) for frame_path in frame_paths])

    if frame_ids is not None:
        frames = frames[frame_ids]

    print(f"[INFO] loaded video with {len(frames)} frames from: {video_path}")

    frames = process_frames(frames, h, w)
    return frames.to(device)


def save_video(frames: torch.Tensor, path, frame_ids=None, save_frame=False, fps=30):
    """Write frames to ``<path>/output.mp4`` and optionally as PNGs.

    Args:
        frames: float frames laid out "T C H W" (per the rearrange pattern),
            expected in [0, 1]; values outside that range are clamped.
        path: output directory (created if missing).
        frame_ids: indices into ``frames`` to write; defaults to all frames.
        save_frame: if True, also write individual frames to ``<path>/frames``.
        fps: frame rate of the encoded video (default 30, as before).
    """
    os.makedirs(path, exist_ok=True)
    if frame_ids is None:
        frame_ids = list(range(len(frames)))
    # Clamp before quantising: out-of-range values would otherwise wrap
    # around when cast to uint8 (e.g. 1.01 * 255 -> 257 -> 1).
    frames = frames[frame_ids].clamp(0, 1)

    video_path = os.path.join(path, "output.mp4")
    proc_frames = (rearrange(frames, "T C H W -> T H W C") * 255).to(torch.uint8).cpu()
    write_video(video_path, proc_frames, fps=fps, video_codec="h264")
    print(f"[INFO] save video to {video_path}")

    if save_frame:
        save_frames(frames, os.path.join(path, "frames"), frame_ids=frame_ids)
    

def save_frames(frames: torch.Tensor, path, ext="png", frame_ids=None):
    """Save each frame as ``<path>/<id:04>.<ext>``.

    Args:
        frames: iterable of image tensors accepted by ``T.ToPILImage``.
        path: output directory (created if missing).
        ext: file extension / image format.
        frame_ids: ids used in the file names, paired with ``frames`` in
            order; defaults to 0..len(frames)-1.
    """
    os.makedirs(path, exist_ok=True)
    if frame_ids is None:
        frame_ids = range(len(frames))
    to_pil = T.ToPILImage()
    for frame_id, frame in zip(frame_ids, frames):
        if frame.is_floating_point():
            # ToPILImage scales floats by 255 and casts to uint8 without
            # clamping, so out-of-range values would wrap around.
            frame = frame.clamp(0, 1)
        to_pil(frame).save(
            os.path.join(path, '{:04}.{}'.format(frame_id, ext)))


def load_latent(latent_path, t, frame_ids=None):
    """Load the cached noisy latents for timestep ``t``.

    Reads ``<latent_path>/noisy_latents_{t}.pt`` with ``torch.load``.

    Args:
        latent_path: directory holding the cached latents.
        t: timestep used to build the file name.
        frame_ids: optional index/indices applied to the loaded latents.

    Returns:
        The loaded object, indexed by ``frame_ids`` when given.

    Raises:
        FileNotFoundError: if the latent file does not exist.
    """
    lp = os.path.join(latent_path, f'noisy_latents_{t}.pt')
    # Explicit raise instead of ``assert``: asserts are stripped under -O.
    if not os.path.exists(lp):
        raise FileNotFoundError(
            f"Latent at timestep {t} not found in {latent_path}.")

    latents = torch.load(lp)
    if frame_ids is not None:
        latents = latents[frame_ids]
    return latents

@torch.no_grad()
def prepare_depth(pipe, frames, frame_ids, work_dir):
    """Load (or compute and cache) one depth map per frame and concatenate them.

    Depth maps are cached as ``<work_dir>/depth/<frame_id:04>.pt`` via
    ``load_depth``.

    Args:
        pipe: model passed through to ``load_depth``.
        frames: iterable of frame tensors, paired with ``frame_ids`` in order.
        frame_ids: ids used to name the cache files.
        work_dir: working directory; a ``depth`` subdirectory is created.

    Returns:
        ``torch.cat`` of the per-frame depth maps.

    Raises:
        ValueError: if no frames are given.
    """
    depth_dir = os.path.join(work_dir, "depth")
    os.makedirs(depth_dir, exist_ok=True)
    depth_ls = [
        load_depth(pipe, os.path.join(depth_dir, "{:04}.pt".format(frame_id)), frame)
        for frame, frame_id in zip(frames, frame_ids)
    ]
    if not depth_ls:
        # Previously an UnboundLocalError surfaced from the log line below.
        raise ValueError("prepare_depth: no frames to process")
    # Report the cache directory, not just the last file processed.
    print(f"[INFO] loaded depth images from {depth_dir}")
    return torch.cat(depth_ls)

# From pix2video: code/file_utils.py

def load_depth(model, depth_path, input_image, dtype=torch.float32):
    """Return the depth map for one frame, using ``depth_path`` as a cache.

    If ``depth_path`` exists it is loaded with ``torch.load`` and returned.
    Otherwise the depth is computed with ``prepare_depth_map`` and saved to
    ``depth_path``. A grayscale PNG preview is also written next to it, with
    ``.pt`` replaced by ``.png``.

    Args:
        model: forwarded to ``prepare_depth_map``; its ``.device`` is used.
        depth_path: cache file path (expected to end in ``.pt``).
        input_image: image tensor converted with ``T.ToPILImage`` after
            ``squeeze()``.
        dtype: dtype forwarded to ``prepare_depth_map``.
    """
    if os.path.exists(depth_path):
        return torch.load(depth_path)

    pil_image = T.ToPILImage()(input_image.squeeze())
    result = prepare_depth_map(
        model, pil_image, dtype=dtype, device=model.device)
    torch.save(result, depth_path)

    # Map [-1, 1] to [0, 255] for an 8-bit preview image.
    preview = ((result + 1.0) / 2.0 * 255).to(torch.uint8)
    preview_path = depth_path.replace(".pt", ".png")
    T.ToPILImage()(preview.squeeze()).convert("L").save(preview_path)
    return result

@torch.no_grad()
def prepare_depth_map(model, image, depth_map=None, batch_size=1, do_classifier_free_guidance=False, dtype=torch.float32, device="cuda"):
    """Return a depth map normalised to [-1, 1] at latent resolution.

    Adapted from diffusers' depth2img ``prepare_depth_map``. Pixels equal to
    -1 (treated as background) are first set to 10 below the smallest other
    depth value.

    Args:
        model: object exposing ``feature_extractor``, ``depth_estimator`` and
            ``vae_scale_factor``.
        image: a PIL image, or an iterable of PIL images / NumPy arrays /
            tensors. Only the first image's size sets the output size.
        depth_map: optional precomputed depth; when given, the estimator is
            skipped. Presumably (N, H, W) - TODO confirm.
        batch_size: result is repeated along dim 0 when it has fewer entries.
        do_classifier_free_guidance: if True, the result is duplicated along
            dim 0.
        dtype: dtype of the result; also the autocast dtype on CUDA.
        device: ``torch.device`` or device string.

    Returns:
        Depth tensor of spatial size (height // vae_scale_factor,
        width // vae_scale_factor) with values in [-1, 1].
    """
    # Accept device strings too (the default "cuda" used to crash on `.type`).
    device = torch.device(device)

    if isinstance(image, Image.Image):
        image = [image]
    else:
        image = list(image)

    if isinstance(image[0], Image.Image):
        width, height = image[0].size
    elif isinstance(image[0], np.ndarray):
        # NumPy images are laid out (H, W[, C]): height comes first.
        height, width = image[0].shape[:2]
    else:
        height, width = image[0].shape[-2:]

    if depth_map is None:
        pixel_values = model.feature_extractor(
            images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=device)
        # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
        # So we use `torch.autocast` here for half precision inference.
        context_manager = torch.autocast(
            "cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
        with context_manager:
            depth_map = model.depth_estimator(pixel_values).predicted_depth
    else:
        # Clone: the background fill below writes in place, and `.to` may
        # hand back the caller's own tensor unchanged.
        depth_map = depth_map.to(device=device, dtype=dtype).clone()

    bg_mask = depth_map == -1
    if bg_mask.any():
        depth_map[bg_mask] = depth_map[~bg_mask].min() - 10

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(height // model.vae_scale_factor,
              width // model.vae_scale_factor),
        mode="bicubic",
        align_corners=False,
    )

    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    # Guard against a constant map (max == min), which would divide by zero
    # and yield NaNs; any non-degenerate range is unaffected.
    depth_range = (depth_max - depth_min).clamp_min(torch.finfo(depth_map.dtype).tiny)
    depth_map = 2.0 * (depth_map - depth_min) / depth_range - 1.0
    depth_map = depth_map.to(dtype)

    # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
    if depth_map.shape[0] < batch_size:
        repeat_by = batch_size // depth_map.shape[0]
        depth_map = depth_map.repeat(repeat_by, 1, 1, 1)

    if do_classifier_free_guidance:
        depth_map = torch.cat([depth_map] * 2)
    return depth_map


def get_latents_dir(latents_path, model_key):
    """Return ``latents_path/<last path component of model_key>``.

    Trailing slashes on ``model_key`` are ignored. A local checkpoint path
    such as ``models/my_model/`` therefore maps to ``<latents_path>/my_model``
    rather than to ``latents_path`` itself.

    Args:
        latents_path: root directory for cached latents.
        model_key: Hugging Face id ("org/name") or local model path.
    """
    model_name = model_key.rstrip("/").split("/")[-1]
    return os.path.join(latents_path, model_name)


def get_controlnet_kwargs(controlnet, x, cond, t, controlnet_cond, controlnet_scale=1.0):
    """Run ControlNet and package its scaled residuals as UNet keyword args.

    Args:
        controlnet: callable invoked as ``controlnet(x, t,
            encoder_hidden_states=cond, controlnet_cond=controlnet_cond,
            return_dict=False)``; must return ``(down_samples, mid_sample)``.
        x: latent input forwarded to ``controlnet``.
        cond: encoder hidden states forwarded to ``controlnet``.
        t: timestep forwarded to ``controlnet``.
        controlnet_cond: conditioning image forwarded to ``controlnet``.
        controlnet_scale: factor applied to every residual.

    Returns:
        dict with "down_block_additional_residuals" (list of scaled down-block
        residuals) and "mid_block_additional_residual" (scaled mid residual).
    """
    down_block_res_samples, mid_block_res_sample = controlnet(
        x,
        t,
        encoder_hidden_states=cond,
        controlnet_cond=controlnet_cond,
        return_dict=False,
    )
    # Scale out of place, like the down-block residuals, so the tensors the
    # ControlNet returned are never modified in place. In-place writes could
    # corrupt a reference held elsewhere or break autograd.
    return {
        "down_block_additional_residuals": [
            sample * controlnet_scale for sample in down_block_res_samples
        ],
        "mid_block_additional_residual": mid_block_res_sample * controlnet_scale,
    }


def get_frame_ids(frame_range, frame_ids=None):
    """Resolve the sorted list of frame indices to process and log a summary.

    Args:
        frame_range: arguments for ``range`` used when ``frame_ids`` is None.
        frame_ids: explicit indices; takes precedence over ``frame_range``.

    Returns:
        A new sorted list of indices. More than four indices are logged as
        "a b ... y z".
    """
    ids = sorted(range(*frame_range) if frame_ids is None else frame_ids)

    if len(ids) <= 4:
        summary = " ".join(f"{i}" for i in ids)
    else:
        first, second = ids[:2]
        penultimate, last = ids[-2:]
        summary = f"{first} {second} ... {penultimate} {last}"

    print("[INFO] frame indexes: ", summary)
    return ids


def _load_cached_control_images(control_subdir, frame_ids):
    """Load the cached control image of every frame id, or None if any is missing."""
    cached = []
    for frame_id in frame_ids:
        image_path = os.path.join(control_subdir, "{:04}.png".format(frame_id))
        if not os.path.exists(image_path):
            return None
        cached.append(load_image(image_path))
    return torch.cat(cached)


def prepare_control(control, frames, frame_ids, save_path):
    """Return control images for the frames, reusing a cache on disk when complete.

    The cache is ``<save_path>/<control>_image/<frame_id:04>.png``. If every
    frame's image is present, the cached images are loaded and returned.
    Otherwise all control images are recomputed with ``control_preprocess``
    and written back to the cache.

    Args:
        control: control type; must be a key of ``CONTROLNET_DICT``.
        frames: frames passed to ``control_preprocess``.
        frame_ids: ids naming the cache files, paired with frames in order.
        save_path: directory containing the cache subdirectory.

    Returns:
        The control images, or None (with a warning) for an unknown type.
    """
    if control not in CONTROLNET_DICT:
        print(f"[WARNING] unknown controlnet type {control}")
        return None

    control_subdir = f'{save_path}/{control}_image'

    if os.path.exists(control_subdir):
        print(f"[INFO] load control image from {control_subdir}.")
        cached_images = _load_cached_control_images(control_subdir, frame_ids)
        if cached_images is not None:
            return cached_images

    print("[INFO] preprocessing control images...")
    control_images = control_preprocess(frames, control)
    print(f"[INFO] save control images to {control_subdir}.")
    os.makedirs(control_subdir, exist_ok=True)
    to_pil = T.ToPILImage()
    for frame_id, image in zip(frame_ids, control_images):
        to_pil(image).save(
            os.path.join(control_subdir, "{:04}.png".format(frame_id)))

    return control_images
    
def isinstance_str(x: object, cls_name: str):
    """Return True if any class in ``x``'s MRO is named ``cls_name``.

    Only class *names* are compared, so the class itself need not be
    importable here. This is handy when patching third-party objects.
    """
    return any(klass.__name__ == cls_name for klass in x.__class__.__mro__)

def init_generator(device: torch.device, fallback: torch.Generator=None):
    """
    Fork the default random generator of ``device`` into a new Generator.

    Args:
        device: ``torch.device`` or device string ("cpu", "cuda", "cuda:1", ...).
        fallback: returned for device types other than cpu/cuda; if None, a
            fork of the CPU generator is returned instead.

    Returns:
        A ``torch.Generator`` whose state copies the device's default RNG.
    """
    device = torch.device(device)
    if device.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())
    if device.type == "cuda":
        # Read the RNG state of *this* device. Without the argument,
        # get_rng_state() reads the current CUDA device, which may differ.
        return torch.Generator(device=device).set_state(torch.cuda.get_rng_state(device))
    if fallback is None:
        return init_generator(torch.device("cpu"))
    return fallback

def join_frame(x, fsize):
    """Join multi-frame tokens: rearrange "(B F) N C" into "B (F N) C" with F = fsize."""
    return rearrange(x, "(B F) N C -> B (F N) C", F=fsize)

def split_frame(x, fsize):
    """Split multi-frame tokens: rearrange "B (F N) C" into "(B F) N C" with F = fsize."""
    return rearrange(x, "B (F N) C -> (B F) N C", F=fsize)

def func_warper(funcs):
    """Wrap a sequence of functions into one callable applying them in order.

    Each function receives the previous result as its first argument plus
    the same keyword arguments. ``funcs`` is read at call time, not copied.
    """
    def composed(x, **kwargs):
        result = x
        for step in funcs:
            result = step(result, **kwargs)
        return result
    return composed

def join_warper(fsize):
    """Return a one-argument callable computing ``join_frame(x, fsize)``."""
    def _join(tokens):
        return join_frame(tokens, fsize)
    return _join

def split_warper(fsize):
    """Return a one-argument callable computing ``split_frame(x, fsize)``."""
    def _split(tokens):
        return split_frame(tokens, fsize)
    return _split